
Commit 62bde15

jQizhang authored and kevalmorabia97 committed
Fix weight-only quantization for TEGroupedMLP (MoE models) (#971)
### What does this PR do?

This PR fixes a critical issue where weight-only quantization fails for MoE models utilizing `TEGroupedMLP` (e.g., Qwen3-30B-A3B).

#### The Problem

In `TEGroupedMLP`, weights are stored per-expert as `weight0`, `weight1`, ..., `weightN`. During `_QuantTEGroupedLinear._setup`, the standard `self.weight` attribute is deleted. The existing `weight_only_quantize` logic expects to find a `self.weight` associated with the quantizer. Because it could not find these "hidden" expert weights, the `weight_quantizer` failed to calibrate, leaving the `_amax` attribute missing. This leads to the following crash during export/inference:

<img width="2792" height="1034" alt="image" src="https://github.com/user-attachments/assets/9e2b1abd-80f4-4b8b-bb95-f8ee7a8f686a" />

```python
File ".../modelopt/torch/quantization/qtensor/nvfp4_tensor.py", line 59, in get_weights_scaling_factor_2_from_quantizer
    assert hasattr(weight_quantizer, "_amax"), "Weight quantizer does not have attribute amax"
```

#### The Solution

1. **Calibration interface:** Introduced `iter_weights_for_calibration` in the `QuantModule` base class.
2. **MoE support:** Overrode this method in `_QuantTEGroupedLinear` to yield all per-expert weights (`weight0`...`weightN`) that share the same quantizer. This ensures the calibrator "sees" all expert weights and calculates a valid `_amax`.

---

### 2. Type of change

* [x] Bug fix

---

### 3. Usage / Reproduction

This issue is reproducible when running weight-only quantization on MoE models like Qwen3-30B-A3B:

```bash
# Step 1: Quantization
torchrun --nproc_per_node 8 examples/quantization/quantize.py \
    --hf-model-id Qwen/Qwen3-30B-A3B \
    --export-quant-cfg nvfp4 \
    --tp 2 \
    --ep 8 \
    --weight-only \
    --megatron-save-path ./qwen3_30b_nvfp4
```

---

### 4. Testing & Verification

* **Models tested:** Qwen3-8B (dense), Qwen3-30B-A3B (MoE).
* **Quantization:** NVFP4/FP8 weight-only quantization.
* **Verification:**
  * Confirmed that `QuantTEGroupedMLP` now correctly shows calculated `_amax` values in the quantization statistics table instead of remaining `dynamic`.
  * Validated that the change does not regress the dense-model (Qwen3-8B) quantization flow.
  * After the fix, the amax of experts is calculated correctly:

```
Quantization Statistics
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━┳━━━━━━━━━━━━┓
┃ Parameter Name                                                      ┃ Shape ┃ Max Value  ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━╇━━━━━━━━━━━━┩
│ decoder.layers.0.self_attention.linear_proj.weight_quantizer._amax  │ ()    │ 7.5781e-01 │
│ decoder.layers.0.self_attention.linear_qkv.weight_quantizer._amax   │ ()    │ 2.8711e-01 │
│ decoder.layers.0.mlp.experts.linear_fc1.weight_quantizer._amax      │ ()    │ 7.1094e-01 │
│ decoder.layers.0.mlp.experts.linear_fc2.weight_quantizer._amax      │ ()    │ 8.6719e-01 │
│ decoder.layers.1.self_attention.linear_proj.weight_quantizer._amax  │ ()    │ 5.8594e-01 │
│ decoder.layers.1.self_attention.linear_qkv.weight_quantizer._amax   │ ()    │ 7.4219e-01 │
│ decoder.layers.1.mlp.experts.linear_fc1.weight_quantizer._amax      │ ()    │ 7.2266e-01 │
│ decoder.layers.1.mlp.experts.linear_fc2.weight_quantizer._amax      │ ()    │ 1.9922e+00 │
│ decoder.layers.2.self_attention.linear_proj.weight_quantizer._amax  │ ()    │ 1.0859e+00 │
│ decoder.layers.2.self_attention.linear_qkv.weight_quantizer._amax   │ ()    │ 1.7812e+00 │
│ decoder.layers.2.mlp.experts.linear_fc1.weight_quantizer._amax      │ ()    │ 7.3047e-01 │
│ decoder.layers.2.mlp.experts.linear_fc2.weight_quantizer._amax      │ ()    │ 1.9219e+00 │
```

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->

## Summary by CodeRabbit

* **New Features**
  * Enhanced weight-only quantization calibration with improved support for specialized quantization modules and grouped-linear quantization paths.
* **Bug Fixes**
  * Fixed handling of missing weight attributes during quantization calibration to prevent incorrect processing.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: larkzhang-nv <larkz@nvidia.com>
Signed-off-by: larkz <larkz@nvidia.com>
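The failure mode and the fix can be sketched with toy stand-ins (hypothetical classes, not the ModelOpt API): a grouped module stores per-expert weights as `weight0`...`weightN` with no `self.weight`, so a generic lookup finds nothing, while an explicit generator lets the shared quantizer see every expert.

```python
class ToyQuantizer:
    """Stands in for the weight_quantizer: tracks the max |w| seen so far."""

    def __init__(self):
        self._amax = None

    def __call__(self, w):
        amax = max(abs(x) for x in w)
        self._amax = amax if self._amax is None else max(self._amax, amax)


class ToyGroupedLinear:
    """Stores one weight per expert as weight0..weightN; `weight` itself does not exist."""

    def __init__(self, expert_weights):
        self.num_gemms = len(expert_weights)
        self.weight_quantizer = ToyQuantizer()
        for i, w in enumerate(expert_weights):
            setattr(self, f"weight{i}", w)

    def iter_weights_for_calibration(self):
        # Mirrors the fix: yield every expert weight paired with the shared quantizer.
        for i in range(self.num_gemms):
            w = getattr(self, f"weight{i}", None)
            if w is not None:
                yield w, self.weight_quantizer


m = ToyGroupedLinear([[0.1, -0.3], [0.7, 0.2], [-1.5, 0.4]])
assert getattr(m, "weight", None) is None  # the generic `self.weight` lookup finds nothing

for w, q in m.iter_weights_for_calibration():
    q(w)
assert m.weight_quantizer._amax == 1.5  # amax now covers all experts
```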
1 parent 2252074 · commit 62bde15

File tree

4 files changed: +20 -6 lines changed


modelopt/torch/quantization/model_calib.py

Lines changed: 4 additions & 5 deletions

```diff
@@ -67,12 +67,11 @@ def weight_only_quantize(model: nn.Module):
     for module in name_to_module.values():
         if module in seen_modules:
             continue
-        for weight_name in weight_attr_names(module):
+
+        if isinstance(module, QuantModule):
             with enable_weight_access_and_writeback(module, model, name_to_module):
-                weight_quantizer = getattr(
-                    module, quantizer_attr_names(weight_name).weight_quantizer
-                )
-                weight_quantizer(getattr(module, weight_name))
+                for weight, weight_quantizer in module.iter_weights_for_calibration():
+                    weight_quantizer(weight)
         seen_modules.add(module)
```
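The reworked loop can be sketched in isolation: instead of guessing weight attribute names, it asks each module for its own `(weight, quantizer)` pairs. The stub classes below are illustrative only; the real function also wraps access in `enable_weight_access_and_writeback` and checks `isinstance(module, QuantModule)`.

```python
class _Quantizer:
    """Records which weights it was calibrated on."""

    def __init__(self):
        self.calibrated_on = []

    def __call__(self, w):
        self.calibrated_on.append(w)


class _Module:
    """Duck-typed stand-in exposing the new calibration generator."""

    def __init__(self, weights):
        self._weights = weights
        self.weight_quantizer = _Quantizer()

    def iter_weights_for_calibration(self):
        for w in self._weights:
            yield w, self.weight_quantizer


def calibrate(modules):
    # Simplified shape of the patched weight_only_quantize loop.
    seen = set()
    for module in modules:
        if id(module) in seen:
            continue
        for weight, weight_quantizer in module.iter_weights_for_calibration():
            weight_quantizer(weight)
        seen.add(id(module))


shared = _Module(["w0", "w1", "w2"])
calibrate([shared, shared])  # the second occurrence is deduplicated via `seen`
assert shared.weight_quantizer.calibrated_on == ["w0", "w1", "w2"]
```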

modelopt/torch/quantization/nn/modules/quant_module.py

Lines changed: 8 additions & 0 deletions

```diff
@@ -119,6 +119,14 @@ def modelopt_post_restore(self, prefix: str = ""):
             if isinstance(module, TensorQuantizer):
                 module.to(non_tq_param_or_buffer.device)
 
+    def iter_weights_for_calibration(self):
+        """Yield ``(weight, weight_quantizer)`` pairs for weight-only calibration."""
+        from modelopt.torch.quantization.utils import quantizer_attr_names, weight_attr_names
+
+        for weight_name in weight_attr_names(self):
+            weight_quantizer = getattr(self, quantizer_attr_names(weight_name).weight_quantizer)
+            yield getattr(self, weight_name), weight_quantizer
+
     def fold_weight(self, keep_attrs: bool = False):
         """Fold the weight for faster eval."""
         # Handle all attributes that end with _weight_quantizer
```
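For ordinary dense modules the base-class generator yields a single pair, so existing calibration behavior is unchanged. A minimal sketch with stub classes (hypothetical, not the real `QuantModule`/`TensorQuantizer`):

```python
class _StubQuantizer:
    """Stand-in quantizer that records the max |w| it calibrates on."""

    def __init__(self):
        self._amax = None

    def __call__(self, w):
        self._amax = max(abs(x) for x in w)


class _DenseStub:
    """Standard case: one `weight` with one `weight_quantizer`."""

    def __init__(self, w):
        self.weight = w
        self.weight_quantizer = _StubQuantizer()

    def iter_weights_for_calibration(self):
        # Default behavior: yield the single standard weight/quantizer pair.
        yield self.weight, self.weight_quantizer


m = _DenseStub([0.5, -2.0, 1.0])
pairs = list(m.iter_weights_for_calibration())
assert len(pairs) == 1  # dense path is unaffected by the MoE fix

for w, q in pairs:
    q(w)
assert m.weight_quantizer._amax == 2.0
```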

modelopt/torch/quantization/plugins/transformer_engine.py

Lines changed: 7 additions & 0 deletions

```diff
@@ -151,6 +151,13 @@ def modelopt_post_restore(self, prefix: str = ""):
         # Remove self.weight after post_restore.
         delattr(self, "weight")
 
+    def iter_weights_for_calibration(self):
+        """Yield ``(weight_i, weight_quantizer)`` for each of the ``num_gemms`` grouped weights."""
+        for i in range(self.num_gemms):
+            weight_i = getattr(self, f"weight{i}", None)
+            if weight_i is not None:
+                yield weight_i, self.weight_quantizer
+
     @staticmethod
     def te_grouped_quantized_linear_fn(package, func_name, self, *args):
         _assert_te_fp8_enabled()
```

modelopt/torch/quantization/utils/core_utils.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -213,7 +213,7 @@ def weight_attr_names(module: nn.Module) -> "Generator[str, None, None]":
     # the standard weight and quantizer case
     weight = getattr(module, "weight", None)
     weight_quantizer = getattr(module, "weight_quantizer", None)
-    if isinstance(weight_quantizer, (TensorQuantizer, SequentialQuantizer)):
+    if weight is not None and isinstance(weight_quantizer, (TensorQuantizer, SequentialQuantizer)):
         yield "weight"
 
     # other weight and quantizer case
```
