diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 997a1069401..4068cfdd0c8 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -25,6 +25,7 @@ Changelog - Add ``--cast_mxfp4_to_nvfp4`` flag to ``examples/llm_ptq/hf_ptq.py`` for closed-form, bit-exact MXFP4 → NVFP4 weight conversion. Supports the GPT-OSS family (``openai/gpt-oss-20b``, ``openai/gpt-oss-120b``). See `examples/llm_ptq/README.md `__ for usage. - DeepSeek PTQ (``examples/deepseek/ptq.py``) now defaults to native top-k calibration with post-hoc per-layer peer-max sync of expert ``input_quantizer.amax``; the all-experts path is preserved behind ``--calib_all_experts``. - Add NVFP4 W4A16 weight-only quantization (``w4a16_nvfp4``): FP4 weights with group_size=16, BF16 activations, no calibration forward pass required. Use ``mtq.W4A16_NVFP4_CFG`` or ``--qformat w4a16_nvfp4`` in ``hf_ptq.py``. vLLM deployment support is in progress. +- Add quantized ``nn.Embedding`` support. ``nn.Embedding`` is now registered in ``QuantModuleRegistry`` and exposes ``weight_quantizer`` (embedding table), ``output_quantizer`` (lookup activations), and a permanently disabled ``input_quantizer`` placeholder — embedding inputs are integer indices and cannot be fake-quantized, so direct ``enable*()`` calls raise. ``export_hf_checkpoint`` packs quantized embedding weights alongside Linear layers. Embedding quantizers are opt-in (``parent_class: nn.Embedding`` disabled by default). **Bug Fixes** diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index 0626d0a8fd5..21e5ba0638d 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -679,6 +679,34 @@ def _process_quantized_modules( raise AssertionError( f"Failed to export module '{name}' (type={type(sub_module).__name__}): {e}" ) from e + elif isinstance(sub_module, nn.Embedding) and hasattr(sub_module, "weight_quantizer"): + # Quantized nn.Embedding: pack the embedding table the same way as Linear + # weights so downstream loaders see the NVFP4/FP8/INT-packed bytes + scales. + # Skip packing when the embedding's weight is tied to another module + # (e.g. tied_word_embeddings → lm_head): _export_quantized_weight reassigns + # the .weight attribute to a new uint8 Parameter, which severs the Python- + # level tie and leaves the other module pointing at a stale float Parameter. + tied_to = [ + other_name + for other_name, other_module in model.named_modules() + if other_module is not sub_module + and getattr(other_module, "weight", None) is sub_module.weight + ] + if tied_to: + warnings.warn( + f"Skipping quantized weight packing for embedding '{name}': its " + f"weight Parameter is shared with {tied_to} (weight tying). Packing " + "would break the tie and produce stale weights in the tied module(s). " + "The embedding will be exported as its fake-quantized float weight." + ) + else: + try: + with fsdp2_aware_weight_update(model, sub_module, reshard=False): + _export_quantized_weight(sub_module, dtype) + except AssertionError as e: + raise AssertionError( + f"Failed to export embedding '{name}' (type={type(sub_module).__name__}): {e}" + ) from e elif ( "Llama4TextExperts" in type(sub_module).__name__ or "GptOssExperts" in type(sub_module).__name__ diff --git a/modelopt/torch/quantization/nn/__init__.py b/modelopt/torch/quantization/nn/__init__.py index af9490c8311..2e6bc64054e 100644 --- a/modelopt/torch/quantization/nn/__init__.py +++ b/modelopt/torch/quantization/nn/__init__.py @@ -18,6 +18,7 @@ from .modules.quant_activations import * from .modules.quant_batchnorm import * from .modules.quant_conv import * +from .modules.quant_embedding import * from .modules.quant_instancenorm import * from .modules.quant_layernorm import * from .modules.quant_linear import * diff --git a/modelopt/torch/quantization/nn/modules/quant_embedding.py b/modelopt/torch/quantization/nn/modules/quant_embedding.py new file mode 100644 index 00000000000..91e9f754c7d --- /dev/null +++ b/modelopt/torch/quantization/nn/modules/quant_embedding.py @@ -0,0 +1,157 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Quantized Embedding.""" + +import contextlib + +import torch +import torch.nn as nn + +from ...tensor_quant import QUANT_DESC_8BIT_PER_TENSOR +from ...utils import is_torch_export_mode +from .quant_module import QuantModule, QuantModuleRegistry +from .tensor_quantizer import SequentialQuantizer, TensorQuantizer + +__all__ = ["QuantEmbedding"] + + +_INPUT_QUANTIZER_ERR = ( + "Cannot configure input_quantizer on a quantized nn.Embedding: the input is integer " + "indices and cannot be fake-quantized. Configure weight_quantizer (and optionally " + "output_quantizer) instead." +) + + +class _UnsettableInputQuantizer(TensorQuantizer): + """TensorQuantizer slot for nn.Embedding.input — present but not enable-able. + + Embedding inputs are integer indices that cannot be fake-quantized. The attribute + is kept so introspection code (export, calibration helpers) can find it. + + Wildcard configs (e.g. the default ``QuantizeConfig`` ``"*"`` rule or + ``NVFP4_DEFAULT_CFG``'s ``*input_quantizer``) are accepted silently, then the + quantizer is force-disabled — wildcards don't really mean "enable embedding + input quant", they mean "enable input quant in general". Direct, explicit + attempts (calling ``enable``/``enable_quant``/``enable_calib``) raise loudly. + """ + + def enable(self): + """Disallowed for embedding inputs.""" + raise RuntimeError(_INPUT_QUANTIZER_ERR) + + def enable_quant(self): + """Disallowed for embedding inputs.""" + raise RuntimeError(_INPUT_QUANTIZER_ERR) + + def enable_calib(self): + """Disallowed for embedding inputs.""" + raise RuntimeError(_INPUT_QUANTIZER_ERR) + + def set_from_attribute_config(self, attribute_cfg): + """Apply the config like any quantizer, then force-disable us. + + This absorbs wildcard configs from stock recipes without raising. The + quantizer's other attributes (``num_bits``, ``axis``, etc.) take on the + config values for introspection, but ``_disabled`` is forced back to + ``True`` so forward is always a no-op. + """ + super().set_from_attribute_config(attribute_cfg) + self._disabled = True + + +@QuantModuleRegistry.register({nn.Embedding: "nn.Embedding"}) +class _QuantEmbedding(QuantModule): + """Quantized version of ``nn.Embedding``. + + The literal input to ``nn.Embedding`` is integer indices, which cannot be + fake-quantized. The ``input_quantizer`` attribute is kept (for symmetry with + other quant modules and for introspection by export/calibration code) but is + permanently disabled — see ``_UnsettableInputQuantizer``. Only the embedding + table (weight) and the lookup output (an activation feeding downstream layers) + are quantizable. + + Quantizer roles: + - ``weight_quantizer``: quantizes the embedding table (``self.weight``). + - ``input_quantizer``: permanently disabled placeholder — direct + ``enable*()`` calls raise; configs that target it are absorbed and the + quantizer is force-disabled. + - ``output_quantizer``: optional activation quantizer for the lookup output, + disabled by default. + """ + + weight_quantizer: TensorQuantizer | SequentialQuantizer + input_quantizer: _UnsettableInputQuantizer + output_quantizer: TensorQuantizer + _enable_weight_quantization: bool + default_quant_desc_weight = QUANT_DESC_8BIT_PER_TENSOR + default_quant_desc_input = QUANT_DESC_8BIT_PER_TENSOR + default_quant_desc_output = QUANT_DESC_8BIT_PER_TENSOR + + @contextlib.contextmanager + def quantize_weight(self): + """Context in which ``self.weight`` is quantized via the dynamic attribute.""" + self._enable_weight_quantization = True + try: + yield + finally: + self._enable_weight_quantization = False + + @staticmethod + def _get_quantized_weight(module: "_QuantEmbedding", weight: torch.Tensor) -> torch.Tensor: + if module._enable_weight_quantization or is_torch_export_mode(): + return module.weight_quantizer(weight) + return weight + + def _setup(self): + """Register weight, (locked) input, and output quantizers.""" + self._register_temp_attribute( + "weight_quantizer", TensorQuantizer(self.default_quant_desc_weight) + ) + # Build the input quantizer disabled. _UnsettableInputQuantizer's mutators raise, + # so we disable it once at construction via direct attribute assignment. + input_quantizer = _UnsettableInputQuantizer(self.default_quant_desc_input) + input_quantizer._disabled = True + self._register_temp_attribute("input_quantizer", input_quantizer) + self._register_temp_attribute( + "output_quantizer", TensorQuantizer(self.default_quant_desc_output) + ) + self.output_quantizer.disable() + self._register_temp_attribute("_enable_weight_quantization", False) + self._register_dynamic_attribute("weight", self._get_quantized_weight) + + def forward(self, input, *args, **kwargs): + """Quantize the embedding table, look up, then optionally quantize the output. + + ``input_quantizer`` is intentionally never applied — embedding inputs are + integer indices. ``_UnsettableInputQuantizer.set_from_attribute_config`` + keeps that quantizer disabled regardless of what configs target it, so we + rely on that invariant rather than a runtime check here. + """ + if is_torch_export_mode(): + # quantize_weight()'s attribute write is not allowed under torch.export; + # weight quantization is still applied inline via _get_quantized_weight's + # is_torch_export_mode() branch. Apply output_quantizer in this path too + # so users who opt into output activation quantization don't silently + # lose it during export — matches QuantInputBase.forward's behavior. + output = super().forward(input, *args, **kwargs) + else: + with self.quantize_weight(): + output = super().forward(input, *args, **kwargs) + return self.output_quantizer(output) + + +# Public alias consistent with quant_linear / quant_conv naming. +QuantEmbedding = _QuantEmbedding diff --git a/modelopt/torch/quantization/utils/core_utils.py b/modelopt/torch/quantization/utils/core_utils.py index cea3d4260e4..15c6504011b 100644 --- a/modelopt/torch/quantization/utils/core_utils.py +++ b/modelopt/torch/quantization/utils/core_utils.py @@ -297,6 +297,11 @@ def is_quantized_linear(module): """Check if a module is a quantized linear module.""" from ..nn import QuantModule, TensorQuantizer + # Embedding has a 2D weight but is not a GEMM op, so calibration passes that operate + # on linear activations (AWQ, SmoothQuant, SVDQuant) must skip it. + if isinstance(module, nn.Embedding): + return False + return ( isinstance(module, QuantModule) and isinstance(getattr(module, "input_quantizer", None), TensorQuantizer) diff --git a/modelopt_recipes/configs/ptq/units/default_disabled_quantizers.yaml b/modelopt_recipes/configs/ptq/units/default_disabled_quantizers.yaml index 1508f942776..4fddfdbbc6a 100644 --- a/modelopt_recipes/configs/ptq/units/default_disabled_quantizers.yaml +++ b/modelopt_recipes/configs/ptq/units/default_disabled_quantizers.yaml @@ -48,3 +48,6 @@ - parent_class: 'nn.LeakyReLU' quantizer_name: '*' enable: false + - parent_class: 'nn.Embedding' + quantizer_name: '*' + enable: false diff --git a/tests/unit/torch/quantization/test_quant_embedding.py b/tests/unit/torch/quantization/test_quant_embedding.py new file mode 100644 index 00000000000..af3931b1892 --- /dev/null +++ b/tests/unit/torch/quantization/test_quant_embedding.py @@ -0,0 +1,194 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests of QuantEmbedding module.""" + +import pytest +import torch +import torch.nn as nn +import torch.nn.functional as F + +import modelopt.torch.quantization as mtq +from modelopt.torch.export.unified_export_hf import _process_quantized_modules +from modelopt.torch.quantization import tensor_quant +from modelopt.torch.quantization import utils as quant_utils +from modelopt.torch.quantization.config import QuantizerAttributeConfig +from modelopt.torch.quantization.conversion import set_quantizer_attributes_partial +from modelopt.torch.quantization.nn import QuantModuleRegistry +from modelopt.torch.quantization.nn.modules.quant_embedding import _UnsettableInputQuantizer +from modelopt.torch.quantization.utils import quantizer_attr_names + +VOCAB_SIZE = 16 +EMBED_DIM = 32 # multiple of the NVFP4 block size (16) so export tests can pack + + +def _make_quant_embedding(**kwargs) -> nn.Module: + """Build an nn.Embedding and convert it through QuantModuleRegistry.""" + return QuantModuleRegistry.convert(nn.Embedding(VOCAB_SIZE, EMBED_DIM, **kwargs)) + + +class TestQuantEmbedding: + """Forward-path behavior of the registered QuantEmbedding wrapper.""" + + def test_default_state_and_no_quant(self): + """Default state: input quant locked-disabled, output quant disabled, weight quant on; + with weight quant also off the wrapper matches plain F.embedding.""" + qemb = _make_quant_embedding() + assert isinstance(qemb.input_quantizer, _UnsettableInputQuantizer) + assert not qemb.input_quantizer.is_enabled + assert not qemb.output_quantizer.is_enabled + assert qemb.weight_quantizer.is_enabled + + qemb.weight_quantizer.disable() + ids = torch.randint(0, VOCAB_SIZE, (4, 6)) + assert torch.allclose(qemb(ids), F.embedding(ids, qemb.weight), rtol=0, atol=0) + + @pytest.mark.parametrize("axis", [None, 0]) + def test_weight_fake_quant(self, axis): + """Per-tensor (axis=None) and per-row (axis=0) weight fake quant match the manual ref.""" + qemb = _make_quant_embedding() + set_quantizer_attributes_partial( + qemb, "*weight_quantizer", QuantizerAttributeConfig(axis=axis).model_dump() + ) + + ids = torch.randint(0, VOCAB_SIZE, (4, 6)) + weight = qemb.weight.detach().clone() + amax = ( + torch.max(torch.abs(weight)) + if axis is None + else quant_utils.reduce_amax(weight, axis=1, keepdims=True) + ) + ref = F.embedding(ids, tensor_quant.fake_tensor_quant(weight, amax)) + assert torch.allclose(qemb(ids), ref, rtol=0, atol=0) + + def test_output_quantizer_applied_when_enabled(self): + """Enabling output_quantizer makes forward equivalent to applying it to the lookup.""" + qemb = _make_quant_embedding() + qemb.weight_quantizer.disable() + qemb.output_quantizer.enable() + ids = torch.randint(0, VOCAB_SIZE, (4, 6)) + with torch.no_grad(): + qemb(ids) # calibrate + + ref = qemb.output_quantizer(F.embedding(ids, qemb.weight)) + assert torch.allclose(qemb(ids), ref, rtol=0, atol=0) + + @pytest.mark.parametrize("method", ["enable", "enable_quant", "enable_calib"]) + def test_input_quantizer_mutators_raise(self, method): + """Each public enable/enable_quant/enable_calib API on input_quantizer raises.""" + qemb = _make_quant_embedding() + with pytest.raises(RuntimeError, match="nn.Embedding"): + getattr(qemb.input_quantizer, method)() + + def test_wildcard_config_keeps_input_quantizer_disabled(self): + """set_from_attribute_config absorbs any cfg but force-disables input_quantizer. + + Stock recipes' ``*input_quantizer`` wildcard (and the default ``QuantizeConfig`` + ``"*"`` rule) target every quantizer including the embedding's input slot. + The quantizer must end up disabled regardless of what the cfg said. + """ + qemb = _make_quant_embedding() + set_quantizer_attributes_partial( + qemb, + "*input_quantizer", + QuantizerAttributeConfig(num_bits=8, axis=None).model_dump(), + ) + assert not qemb.input_quantizer.is_enabled + # Forward still works — input_quantizer is disabled and never applied. + qemb(torch.randint(0, VOCAB_SIZE, (4, 6))) + + +def _embedding_nvfp4_cfg() -> dict: + """Stock-NVFP4-style cfg that opts the embedding's weight quantizer in.""" + nvfp4 = { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, + } + return { + "quant_cfg": [ + {"quantizer_name": "*", "enable": False}, + { + "parent_class": "nn.Embedding", + "quantizer_name": "*weight_quantizer", + "cfg": dict(nvfp4), + }, + ], + "algorithm": "max", + } + + +class _EmbeddingOnly(nn.Module): + """Single-embedding wrapper exposing forward + named_modules iteration.""" + + def __init__(self): + """Build the lone embedding submodule.""" + super().__init__() + self.embedding = nn.Embedding(VOCAB_SIZE, EMBED_DIM) + + def forward(self, ids): + """Look up embeddings for the given token IDs.""" + return self.embedding(ids) + + +class _TiedEmbeddingLM(nn.Module): + """Embedding + Linear lm_head with tied weights (lm_head.weight is embedding.weight).""" + + def __init__(self): + """Build embedding + lm_head and tie their weight Parameters.""" + super().__init__() + self.embedding = nn.Embedding(VOCAB_SIZE, EMBED_DIM) + self.lm_head = nn.Linear(EMBED_DIM, VOCAB_SIZE, bias=False) + self.lm_head.weight = self.embedding.weight # Python-level tie + + def forward(self, ids): + """Embed then project to vocab logits with the tied weight.""" + return self.lm_head(self.embedding(ids)) + + +class TestQuantEmbeddingExport: + """Export-path coverage: weight packing and tied-weight guard.""" + + def test_quantized_weight_is_packed_and_scales_registered(self): + """End-to-end: _process_quantized_modules packs the embedding weight and + registers ``weight_scale`` + ``weight_scale_2`` buffers.""" + model = _EmbeddingOnly() + model = mtq.quantize( + model, _embedding_nvfp4_cfg(), lambda m: m(torch.randint(0, VOCAB_SIZE, (2, 4))) + ) + _process_quantized_modules(model, dtype=torch.float16) + + attrs = quantizer_attr_names("weight") + assert model.embedding.weight.dtype == torch.uint8 + assert hasattr(model.embedding, attrs.weight_scale) + assert hasattr(model.embedding, attrs.weight_scale_2) + # input_scale is not registered (input_quantizer is permanently disabled). + assert not hasattr(model.embedding, attrs.input_scale) + + def test_tied_embedding_export_skips_packing(self): + """When the embedding weight is shared with lm_head, packing is skipped + with a warning so the tie survives the export.""" + model = _TiedEmbeddingLM() + assert model.lm_head.weight is model.embedding.weight # sanity + + model = mtq.quantize( + model, _embedding_nvfp4_cfg(), lambda m: m(torch.randint(0, VOCAB_SIZE, (2, 4))) + ) + orig_dtype = model.embedding.weight.dtype + with pytest.warns(UserWarning, match="tied"): + _process_quantized_modules(model, dtype=torch.float16) + + # Weight Parameter unchanged (not packed to uint8) and still tied. + assert model.embedding.weight.dtype == orig_dtype + assert model.lm_head.weight is model.embedding.weight