Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions modelopt/torch/quantization/model_calib.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,11 @@ def weight_only_quantize(model: nn.Module):
for module in name_to_module.values():
if module in seen_modules:
continue
for weight_name in weight_attr_names(module):

if isinstance(module, QuantModule):
with enable_weight_access_and_writeback(module, model, name_to_module):
weight_quantizer = getattr(
module, quantizer_attr_names(weight_name).weight_quantizer
)
weight_quantizer(getattr(module, weight_name))
for weight, weight_quantizer in module.iter_weights_for_calibration():
weight_quantizer(weight)
seen_modules.add(module)


Expand Down
8 changes: 8 additions & 0 deletions modelopt/torch/quantization/nn/modules/quant_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,14 @@ def modelopt_post_restore(self, prefix: str = ""):
if isinstance(module, TensorQuantizer):
module.to(non_tq_param_or_buffer.device)

def iter_weights_for_calibration(self):
    """Yield ``(weight, weight_quantizer)`` pairs for weight-only calibration.

    Walks every quantizable weight attribute reported by ``weight_attr_names``
    and pairs it with the corresponding weight quantizer attribute.
    """
    # Function-scope import — presumably to avoid a circular import between
    # the nn modules and the utils package; confirm before hoisting to module level.
    from modelopt.torch.quantization.utils import quantizer_attr_names, weight_attr_names

    for name in weight_attr_names(self):
        quantizer_attr = quantizer_attr_names(name).weight_quantizer
        yield getattr(self, name), getattr(self, quantizer_attr)

def fold_weight(self, keep_attrs: bool = False):
"""Fold the weight for faster eval."""
# Handle all attributes that end with _weight_quantizer
Expand Down
7 changes: 7 additions & 0 deletions modelopt/torch/quantization/plugins/transformer_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,13 @@ def modelopt_post_restore(self, prefix: str = ""):
# Remove self.weight after post_restore.
delattr(self, "weight")

def iter_weights_for_calibration(self):
    """Yield ``(weight_i, weight_quantizer)`` for each of the ``num_gemms`` grouped weights.

    Grouped TE linear layers keep their weights in attributes named
    ``weight0`` ... ``weight{num_gemms - 1}`` instead of a single ``weight``,
    so the inherited base implementation (which looks for ``self.weight``)
    would silently find nothing; this override enumerates the per-GEMM
    weights instead. Every weight is paired with the single shared
    ``self.weight_quantizer``.
    """
    for idx in range(self.num_gemms):
        # A missing slot is skipped silently rather than raising.
        candidate = getattr(self, f"weight{idx}", None)
        if candidate is None:
            continue
        yield candidate, self.weight_quantizer

@staticmethod
def te_grouped_quantized_linear_fn(package, func_name, self, *args):
_assert_te_fp8_enabled()
Expand Down
2 changes: 1 addition & 1 deletion modelopt/torch/quantization/utils/core_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def weight_attr_names(module: nn.Module) -> "Generator[str, None, None]":
# the standard weight and quantizer case
weight = getattr(module, "weight", None)
weight_quantizer = getattr(module, "weight_quantizer", None)
if isinstance(weight_quantizer, (TensorQuantizer, SequentialQuantizer)):
if weight is not None and isinstance(weight_quantizer, (TensorQuantizer, SequentialQuantizer)):
yield "weight"

# other weight and quantizer case
Expand Down
Loading