Skip to content

Commit 45ab8ca

Browse files
committed
Fix nvfp4 weight-only quantization for TEGroupedMLP (MoE models)
Signed-off-by: larkzhang-nv <larkz@nvidia.com>
1 parent a34d613 commit 45ab8ca

File tree

4 files changed

+39
-6
lines changed

4 files changed

+39
-6
lines changed

modelopt/torch/quantization/model_calib.py

Lines changed: 11 additions & 5 deletions
Original file line number · Diff line number · Diff line change
@@ -70,12 +70,18 @@ def weight_only_quantize(model: nn.Module):
7070
for name, module in model.named_modules():
7171
if module in seen_modules:
7272
continue
73-
for weight_name in weight_attr_names(module):
73+
74+
if isinstance(module, QuantModule):
7475
with enable_weight_access_and_writeback(module, model):
75-
weight_quantizer = getattr(
76-
module, quantizer_attr_names(weight_name).weight_quantizer
77-
)
78-
weight_quantizer(getattr(module, weight_name))
76+
for weight, weight_quantizer in module.iter_weights_for_calibration():
77+
weight_quantizer(weight)
78+
else:
79+
for weight_name in weight_attr_names(module):
80+
with enable_weight_access_and_writeback(module, model):
81+
weight_quantizer = getattr(
82+
module, quantizer_attr_names(weight_name).weight_quantizer
83+
)
84+
weight_quantizer(getattr(module, weight_name))
7985
seen_modules.add(module)
8086

8187

modelopt/torch/quantization/nn/modules/quant_module.py

Lines changed: 15 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -119,6 +119,21 @@ def modelopt_post_restore(self, prefix: str = ""):
119119
if isinstance(module, TensorQuantizer):
120120
module.to(non_tq_param_or_buffer.device)
121121

122+
def iter_weights_for_calibration(self):
    """Yield ``(weight, weight_quantizer)`` pairs for weight-only calibration.

    The default implementation walks every weight attribute reported by
    :func:`~modelopt.torch.quantization.utils.weight_attr_names` and pairs
    each weight with its matching quantizer. Subclasses that keep weights
    under non-standard attribute names (e.g. ``_QuantTEGroupedLinear`` uses
    ``weight0``, ``weight1``, ...) should override this method.
    """
    # Imported lazily to avoid a circular import at module load time.
    from modelopt.torch.quantization.utils import quantizer_attr_names, weight_attr_names

    for name in weight_attr_names(self):
        quantizer = getattr(self, quantizer_attr_names(name).weight_quantizer)
        yield getattr(self, name), quantizer
136+
122137
def fold_weight(self, keep_attrs: bool = False):
123138
"""Fold the weight for faster eval."""
124139
# Handle all attributes that end with _weight_quantizer

modelopt/torch/quantization/plugins/transformer_engine.py

Lines changed: 12 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -151,6 +151,18 @@ def modelopt_post_restore(self, prefix: str = ""):
151151
# Remove self.weight after post_restore.
152152
delattr(self, "weight")
153153

154+
def iter_weights_for_calibration(self):
    """Yield ``(weight_i, weight_quantizer)`` for each of the ``num_gemms`` grouped weights.

    Overridden because ``self.weight`` is deleted in ``_setup``, so the
    base-class implementation (which relies on ``weight_attr_names``) would
    discover no weights. Instead, walk ``weight0`` .. ``weight{num_gemms - 1}``
    directly, skipping any index that has no weight attached.
    """
    for idx in range(self.num_gemms):
        candidate = getattr(self, f"weight{idx}", None)
        if candidate is None:
            continue
        yield candidate, self.weight_quantizer
165+
154166
@staticmethod
155167
def te_grouped_quantized_linear_fn(package, func_name, self, *args):
156168
_assert_te_fp8_enabled()

modelopt/torch/quantization/utils.py

Lines changed: 1 addition & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -229,7 +229,7 @@ def weight_attr_names(module: nn.Module) -> "Generator[str, None, None]":
229229
# the standard weight and quantizer case
230230
weight = getattr(module, "weight", None)
231231
weight_quantizer = getattr(module, "weight_quantizer", None)
232-
if isinstance(weight_quantizer, (TensorQuantizer, SequentialQuantizer)):
232+
if weight is not None and isinstance(weight_quantizer, (TensorQuantizer, SequentialQuantizer)):
233233
yield "weight"
234234

235235
# other weight and quantizer case

0 commit comments

Comments (0)