Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ Changelog
- Add ``--cast_mxfp4_to_nvfp4`` flag to ``examples/llm_ptq/hf_ptq.py`` for closed-form, bit-exact MXFP4 → NVFP4 weight conversion. Supports the GPT-OSS family (``openai/gpt-oss-20b``, ``openai/gpt-oss-120b``). See `examples/llm_ptq/README.md <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/llm_ptq#mxfp4--nvfp4-cast-for-gpt-oss>`__ for usage.
- DeepSeek PTQ (``examples/deepseek/ptq.py``) now defaults to native top-k calibration with post-hoc per-layer peer-max sync of expert ``input_quantizer.amax``; the all-experts path is preserved behind ``--calib_all_experts``.
- Add NVFP4 W4A16 weight-only quantization (``w4a16_nvfp4``): FP4 weights with group_size=16, BF16 activations, no calibration forward pass required. Use ``mtq.W4A16_NVFP4_CFG`` or ``--qformat w4a16_nvfp4`` in ``hf_ptq.py``. vLLM deployment support is in progress.
- Add quantized ``nn.Embedding`` support. ``nn.Embedding`` is now registered in ``QuantModuleRegistry`` and exposes ``weight_quantizer`` (embedding table), ``output_quantizer`` (lookup activations), and a permanently disabled ``input_quantizer`` placeholder — embedding inputs are integer indices and cannot be fake-quantized, so direct ``enable*()`` calls raise. ``export_hf_checkpoint`` packs quantized embedding weights alongside Linear layers. Embedding quantizers are opt-in (``parent_class: nn.Embedding`` disabled by default).

**Bug Fixes**

Expand Down
28 changes: 28 additions & 0 deletions modelopt/torch/export/unified_export_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -679,6 +679,34 @@ def _process_quantized_modules(
raise AssertionError(
f"Failed to export module '{name}' (type={type(sub_module).__name__}): {e}"
) from e
elif isinstance(sub_module, nn.Embedding) and hasattr(sub_module, "weight_quantizer"):
# Quantized nn.Embedding: pack the embedding table the same way as Linear
# weights so downstream loaders see the NVFP4/FP8/INT-packed bytes + scales.
# Skip packing when the embedding's weight is tied to another module
# (e.g. tied_word_embeddings → lm_head): _export_quantized_weight reassigns
# the .weight attribute to a new uint8 Parameter, which severs the Python-
# level tie and leaves the other module pointing at a stale float Parameter.
tied_to = [
other_name
for other_name, other_module in model.named_modules()
if other_module is not sub_module
and getattr(other_module, "weight", None) is sub_module.weight
]
if tied_to:
warnings.warn(
f"Skipping quantized weight packing for embedding '{name}': its "
f"weight Parameter is shared with {tied_to} (weight tying). Packing "
"would break the tie and produce stale weights in the tied module(s). "
"The embedding will be exported as its fake-quantized float weight."
)
else:
try:
with fsdp2_aware_weight_update(model, sub_module, reshard=False):
_export_quantized_weight(sub_module, dtype)
except AssertionError as e:
raise AssertionError(
f"Failed to export embedding '{name}' (type={type(sub_module).__name__}): {e}"
) from e
elif (
"Llama4TextExperts" in type(sub_module).__name__
or "GptOssExperts" in type(sub_module).__name__
Expand Down
1 change: 1 addition & 0 deletions modelopt/torch/quantization/nn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from .modules.quant_activations import *
from .modules.quant_batchnorm import *
from .modules.quant_conv import *
from .modules.quant_embedding import *
from .modules.quant_instancenorm import *
from .modules.quant_layernorm import *
from .modules.quant_linear import *
Expand Down
157 changes: 157 additions & 0 deletions modelopt/torch/quantization/nn/modules/quant_embedding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Quantized Embedding."""

import contextlib

import torch
import torch.nn as nn

from ...tensor_quant import QUANT_DESC_8BIT_PER_TENSOR
from ...utils import is_torch_export_mode
from .quant_module import QuantModule, QuantModuleRegistry
from .tensor_quantizer import SequentialQuantizer, TensorQuantizer

__all__ = ["QuantEmbedding"]


_INPUT_QUANTIZER_ERR = (
"Cannot configure input_quantizer on a quantized nn.Embedding: the input is integer "
"indices and cannot be fake-quantized. Configure weight_quantizer (and optionally "
"output_quantizer) instead."
)


class _UnsettableInputQuantizer(TensorQuantizer):
"""TensorQuantizer slot for nn.Embedding.input — present but not enable-able.

Embedding inputs are integer indices that cannot be fake-quantized. The attribute
is kept so introspection code (export, calibration helpers) can find it.

Wildcard configs (e.g. the default ``QuantizeConfig`` ``"*"`` rule or
``NVFP4_DEFAULT_CFG``'s ``*input_quantizer``) are accepted silently, then the
quantizer is force-disabled — wildcards don't really mean "enable embedding
input quant", they mean "enable input quant in general". Direct, explicit
attempts (calling ``enable``/``enable_quant``/``enable_calib``) raise loudly.
"""

def enable(self):
"""Disallowed for embedding inputs."""
raise RuntimeError(_INPUT_QUANTIZER_ERR)

def enable_quant(self):
"""Disallowed for embedding inputs."""
raise RuntimeError(_INPUT_QUANTIZER_ERR)

def enable_calib(self):
"""Disallowed for embedding inputs."""
raise RuntimeError(_INPUT_QUANTIZER_ERR)

def set_from_attribute_config(self, attribute_cfg):
"""Apply the config like any quantizer, then force-disable us.

This absorbs wildcard configs from stock recipes without raising. The
quantizer's other attributes (``num_bits``, ``axis``, etc.) take on the
config values for introspection, but ``_disabled`` is forced back to
``True`` so forward is always a no-op.
"""
super().set_from_attribute_config(attribute_cfg)
self._disabled = True


@QuantModuleRegistry.register({nn.Embedding: "nn.Embedding"})
class _QuantEmbedding(QuantModule):
"""Quantized version of ``nn.Embedding``.

The literal input to ``nn.Embedding`` is integer indices, which cannot be
fake-quantized. The ``input_quantizer`` attribute is kept (for symmetry with
other quant modules and for introspection by export/calibration code) but is
permanently disabled — see ``_UnsettableInputQuantizer``. Only the embedding
table (weight) and the lookup output (an activation feeding downstream layers)
are quantizable.

Quantizer roles:
- ``weight_quantizer``: quantizes the embedding table (``self.weight``).
- ``input_quantizer``: permanently disabled placeholder — direct
``enable*()`` calls raise; configs that target it are absorbed and the
quantizer is force-disabled.
- ``output_quantizer``: optional activation quantizer for the lookup output,
disabled by default.
"""

weight_quantizer: TensorQuantizer | SequentialQuantizer
input_quantizer: _UnsettableInputQuantizer
output_quantizer: TensorQuantizer
_enable_weight_quantization: bool
default_quant_desc_weight = QUANT_DESC_8BIT_PER_TENSOR
default_quant_desc_input = QUANT_DESC_8BIT_PER_TENSOR
default_quant_desc_output = QUANT_DESC_8BIT_PER_TENSOR

@contextlib.contextmanager
def quantize_weight(self):
"""Context in which ``self.weight`` is quantized via the dynamic attribute."""
self._enable_weight_quantization = True
try:
yield
finally:
self._enable_weight_quantization = False

@staticmethod
def _get_quantized_weight(module: "_QuantEmbedding", weight: torch.Tensor) -> torch.Tensor:
if module._enable_weight_quantization or is_torch_export_mode():
return module.weight_quantizer(weight)
return weight

def _setup(self):
"""Register weight, (locked) input, and output quantizers."""
self._register_temp_attribute(
"weight_quantizer", TensorQuantizer(self.default_quant_desc_weight)
)
# Build the input quantizer disabled. _UnsettableInputQuantizer's mutators raise,
# so we disable it once at construction via direct attribute assignment.
input_quantizer = _UnsettableInputQuantizer(self.default_quant_desc_input)
input_quantizer._disabled = True
self._register_temp_attribute("input_quantizer", input_quantizer)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you help me understand:

why we have an input_quantizer here? Isn't this a weight quantizer only?

self._register_temp_attribute(
"output_quantizer", TensorQuantizer(self.default_quant_desc_output)
)
self.output_quantizer.disable()
self._register_temp_attribute("_enable_weight_quantization", False)
self._register_dynamic_attribute("weight", self._get_quantized_weight)

def forward(self, input, *args, **kwargs):
"""Quantize the embedding table, look up, then optionally quantize the output.

``input_quantizer`` is intentionally never applied — embedding inputs are
integer indices. ``_UnsettableInputQuantizer.set_from_attribute_config``
keeps that quantizer disabled regardless of what configs target it, so we
rely on that invariant rather than a runtime check here.
"""
if is_torch_export_mode():
# quantize_weight()'s attribute write is not allowed under torch.export;
# weight quantization is still applied inline via _get_quantized_weight's
# is_torch_export_mode() branch. Apply output_quantizer in this path too
# so users who opt into output activation quantization don't silently
# lose it during export — matches QuantInputBase.forward's behavior.
output = super().forward(input, *args, **kwargs)
else:
with self.quantize_weight():
output = super().forward(input, *args, **kwargs)
return self.output_quantizer(output)


# Public alias consistent with quant_linear / quant_conv naming.
QuantEmbedding = _QuantEmbedding
5 changes: 5 additions & 0 deletions modelopt/torch/quantization/utils/core_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,11 @@ def is_quantized_linear(module):
"""Check if a module is a quantized linear module."""
from ..nn import QuantModule, TensorQuantizer

# Embedding has a 2D weight but is not a GEMM op, so calibration passes that operate
# on linear activations (AWQ, SmoothQuant, SVDQuant) must skip it.
if isinstance(module, nn.Embedding):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we need this?

return False

return (
isinstance(module, QuantModule)
and isinstance(getattr(module, "input_quantizer", None), TensorQuantizer)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,6 @@
- parent_class: 'nn.LeakyReLU'
quantizer_name: '*'
enable: false
- parent_class: 'nn.Embedding'
quantizer_name: '*'
enable: false
Loading
Loading