[OMNIML-4730] Support quantized nn.Embedding#1495
Conversation
|
No actionable comments were generated in the recent review. 🎉 ℹ️ Recent review info⚙️ Run configurationConfiguration used: Path: .coderabbit.yaml Review profile: CHILL Plan: Enterprise Run ID: 📒 Files selected for processing (7)
✅ Files skipped from review due to trivial changes (2)
🚧 Files skipped from review as they are similar to previous changes (4)
📝 WalkthroughWalkthroughAdds QuantEmbedding (quantized nn.Embedding) with gated weight quantization, a permanently disabled input-quantizer, optional output quantization, export packing support (with tied-weight skip), calibration exclusion, default-disabled config entry, unit tests, and a changelog entry. ChangesQuantized Embedding Support
sequenceDiagram
participant Client
participant QuantEmbedding
participant WeightQuantizer
participant OutputQuantizer
Client->>QuantEmbedding: forward(input_indices)
QuantEmbedding->>QuantEmbedding: ensure input_quantizer disabled
QuantEmbedding->>WeightQuantizer: get quantized weight (if enabled or export)
alt quantized weight returned
QuantEmbedding->>QuantEmbedding: lookup with quantized weight
else raw weight used
QuantEmbedding->>QuantEmbedding: lookup with raw weight
end
QuantEmbedding->>OutputQuantizer: apply output quantizer if enabled
QuantEmbedding-->>Client: return embeddings
Estimated code review effort: Possibly related PRs:
Suggested reviewers:
🚥 Pre-merge checks | ✅ 6✅ Passed checks (6 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches📝 Generate docstrings
🧪 Generate unit tests (beta)
Comment |
|
cjluo-nv
left a comment
There was a problem hiding this comment.
Bot review — DM the bot to share feedback.
Quantized nn.Embedding support cleanly mirrors the _QuantLinear/QuantLinearConvBase pattern (dynamic-attribute weight, quantize_weight context, _register_temp_attribute), so design-wise it slots in well — no second composition system. The wildcard-tolerance via _UnsettableInputQuantizer is unusual but justified: stock recipes apply *input_quantizer enables, and the YAML parent_class: nn.Embedding, *, enable: false rule is appended last in every preset that uses _default_disabled_quantizer_cfg, so the disabled state is restored before forward. Unit tests cover the lock semantics and weight quant against fake_tensor_quant reference.
Three things worth a maintainer look before approving:
-
output_quantizeris silently bypassed undertorch.export._QuantEmbedding.forwarddoesif is_torch_export_mode(): return super().forward(...)— that path never callsself.output_quantizer(output).QuantLinearConvBase/QuantInputBaseboth keep the output_quantizer in the export path. If a user opts intooutput_quantizerand thentorch.exports, they'll lose it without warning. Probably harmless today (output_quantizer is off by default) but it's an inconsistency. -
Tied embeddings (
tied_word_embeddings=True) likely break on export._export_quantized_weightdoessetattr(sub_module, weight_name, nn.Parameter(quantized_weight, ...)), replacingembedding.weightwith a new Parameter holding packed uint8 bytes. Iflm_head.weightwas tied to the same Parameter, the tie is severed andlm_headkeeps a stale float weight;postprocess_state_dict's tied-weight dedup will then drop one of the keys from the safetensors output. The PR description's example uses an embedding-only model, which sidesteps this — but in real LLMs (Llama/Qwen with tied embeddings) this needs at least a guard or explicit warning. -
No export-path test. All new tests are pure forward tests; the new
_process_quantized_modulesbranch routingnn.Embeddingthrough_export_quantized_weighthas no coverage. Given (2), an export round-trip test on a tiny tied-embedding model would catch the issue. The PR description says it was verified manually on an embedding-only model — that's exactly the case that doesn't exercise the tying path.
Smaller/optional: the _UnsettableInputQuantizer.enable* overrides catch user-facing direct calls, but set_from_attribute_config({"enable": True}) writes _disabled directly via setattr, so the only real defense is the runtime check in forward. The current docstring already explains this; just confirm the runtime guard is the load-bearing one and the method overrides are belt-and-suspenders.
Register nn.Embedding in QuantModuleRegistry so the embedding table and the lookup activations participate in quantization. The literal input is integer indices, so input_quantizer is a non-configurable placeholder that raises on direct enable*() calls and at forward-time if its _disabled flag is flipped — wildcard configs (e.g. NVFP4_DEFAULT_CFG's *input_quantizer) are accepted silently so the stock deny-all → enable wildcards → opt-out pattern continues to work, and the opt-out is installed by default (parent_class: nn.Embedding in default_disabled_quantizers.yaml). export_hf_checkpoint packs quantized embedding weights through the same path as Linear layers. Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
- Apply output_quantizer in the torch.export branch of _QuantEmbedding.forward so users who opt into output activation quantization don't silently lose it during export. Matches QuantInputBase.forward's behavior. - Detect Python-level weight tying (e.g. tied_word_embeddings → lm_head) in _process_quantized_modules and skip packing the embedding when the .weight Parameter is shared, with a UserWarning. Packing would otherwise reassign the embedding's .weight to a new uint8 Parameter, severing the tie and leaving the tied module pointing at a stale float Parameter. - Add export-path tests covering the normal pack flow (weight → uint8 + weight_scale + weight_scale_2 buffers) and the tied-embedding skip path (weight unchanged, warning raised, tie preserved). Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
The previous design raised in _QuantEmbedding.forward whenever
input_quantizer.is_enabled, on the theory that any non-disable config was
an explicit user mistake. That assumption was wrong for wildcard configs:
the default QuantizeConfig is just [{"quantizer_name": "*", "cfg":
{"num_bits": 8, ...}}] (no embedding opt-out), so the wildcard enables
embed_tokens.input_quantizer for tiny Llama-style tests and the forward
guard fires — breaking test_peft_save_load and test_transformers_save_load.
Switch _UnsettableInputQuantizer.set_from_attribute_config to absorb the
incoming config like a normal quantizer, then force _disabled = True at
the end. The "throw on explicit set" semantics are preserved via the
.enable / .enable_quant / .enable_calib overrides, which catch the direct
mistakes users would actually make. The forward-time guard (and the
corresponding test) are removed since the invariant is now maintained at
the configure step.
Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
4c4db31 to
a932284
Compare
cjluo-nv
left a comment
There was a problem hiding this comment.
Bot review — DM the bot to share feedback.
Re-review: all three critical comments from the previous round are addressed.
- output_quantizer under torch.export:
forward()now restructured soreturn self.output_quantizer(output)lives outside the if/else — applied in both the export and non-export paths, matchingQuantInputBase.forward. ✅ - Tied-embedding export safety:
_process_quantized_modulesnow scansmodel.named_modules()for any other module whoseweightis the same Parameter object as the embedding's, and skips packing with aUserWarningwhen a tie is found, preservingtied_word_embeddingssemantics. ✅ - Export-path tests: New
TestQuantEmbeddingExportcovers both the success path (uint8-packed weight +weight_scale/weight_scale_2registered, noinput_scale) and the tied-weight guard (warning + tie preserved). ✅
Design-wise this isn't a new abstraction — quant_embedding.py slots in next to quant_linear.py/quant_conv.py/etc., reusing QuantModule, QuantModuleRegistry, _register_temp_attribute, _register_dynamic_attribute, and the quantize_weight context manager pattern. Previous reviewer explicitly endorsed the design.
Minor docs drift: PR body still says "forward() raises if _disabled is flipped via any back door", but the current forward has no such runtime check. That's fine in practice — nn.Embedding.forward never invokes input_quantizer, so a back-door enable can't cause an integer-index tensor to hit the fake-quant path. Worth tightening the PR description but not a blocker.
Complex PR: spans 7 directories (≥ 5). Looping in a human for approval.
| # so we disable it once at construction via direct attribute assignment. | ||
| input_quantizer = _UnsettableInputQuantizer(self.default_quant_desc_input) | ||
| input_quantizer._disabled = True | ||
| self._register_temp_attribute("input_quantizer", input_quantizer) |
There was a problem hiding this comment.
could you help me understand:
why we have an input_quantizer here? Isn't this a weight quantizer only?
|
|
||
| # Embedding has a 2D weight but is not a GEMM op, so calibration passes that operate | ||
| # on linear activations (AWQ, SmoothQuant, SVDQuant) must skip it. | ||
| if isinstance(module, nn.Embedding): |
cjluo-nv
left a comment
There was a problem hiding this comment.
do we plan to run some fake quant evals?
What does this PR do?
Type of change: new feature
Register
nn.EmbeddinginQuantModuleRegistryso the embedding table and lookup activations participate in quantization end-to-end:modelopt/torch/quantization/nn/modules/quant_embedding.pyexposesweight_quantizer(embedding table),output_quantizer(lookup activations, off by default), and aninput_quantizerplaceholder. Embedding inputs are integer indices that cannot be fake-quantized, so directenable()/enable_quant()/enable_calib()calls oninput_quantizerraise, andforward()raises if_disabledis flipped via any back door. Wildcard configs (*input_quantizer) are accepted silently so the stock deny-all → enable-wildcards → opt-out pattern inNVFP4_DEFAULT_CFGand friends still works.default_disabled_quantizers.yamlinstallsparent_class: nn.Embedding, enable: falseso embedding quantization is opt-in and existing model behavior is unchanged.is_quantized_linearincore_utils.pyearly-returnsFalsefornn.Embeddingso AWQ / SmoothQuant / SVDQuant don't treat it as a GEMM op._process_quantized_modulesinunified_export_hf.pyroutes quantizednn.Embeddingmodules through_export_quantized_weight, so the exported checkpoint contains the packed NVFP4 / FP8 / INT bytes plusweight_scale*buffers, exactly like Linear layers.Usage
Testing
tests/unit/torch/quantization/test_quant_embedding.pycover: default quantizer state, no-quant identity, per-tensor and per-row weight fake quant against the manualtensor_quant.fake_tensor_quantreference, output quantizer activation, locked-mutator raises (parametrized overenable/enable_quant/enable_calib), forward-time guard for back-door_disabled = False, and the wildcard-then-opt-out pattern. All 9 cases pass.mtq.quantizewithNVFP4_DEFAULT_CFG+ the embedding opt-in producesembedding.weight (uint8),embedding.weight_scale (float8_e4m3fn),embedding.weight_scale_2 (float32)in the exported safetensors, with"quant_algo": "NVFP4"inhf_quant_config.json.Before your PR is "Ready for review"
Make sure you read and follow Contributor guidelines and your commits are signed (
git commit -s -S).Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded
trust_remote_code=True,torch.load(..., weights_only=False),pickle, etc.).parent_class: nn.Embedding, enable: falseindefault_disabled_quantizers.yaml, so existing model behavior is unchanged.CONTRIBUTING.md: N/A/claude reviewafter the PR is up.Additional Information
Summary by CodeRabbit
New Features
Bug Fixes
Tests
Documentation