Skip to content

Commit 83b7319

Browse files
committed
minor fix
Signed-off-by: larkz <larkz@nvidia.com>
1 parent 052e360 commit 83b7319

File tree

3 files changed

+2
-21
lines changed

3 files changed

+2
-21
lines changed

modelopt/torch/quantization/model_calib.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -72,13 +72,6 @@ def weight_only_quantize(model: nn.Module):
7272
with enable_weight_access_and_writeback(module, model, name_to_module):
7373
for weight, weight_quantizer in module.iter_weights_for_calibration():
7474
weight_quantizer(weight)
75-
else:
76-
for weight_name in weight_attr_names(module):
77-
with enable_weight_access_and_writeback(module, model, name_to_module):
78-
weight_quantizer = getattr(
79-
module, quantizer_attr_names(weight_name).weight_quantizer
80-
)
81-
weight_quantizer(getattr(module, weight_name))
8275
seen_modules.add(module)
8376

8477

modelopt/torch/quantization/nn/modules/quant_module.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -120,14 +120,7 @@ def modelopt_post_restore(self, prefix: str = ""):
120120
module.to(non_tq_param_or_buffer.device)
121121

122122
def iter_weights_for_calibration(self):
123-
"""Yield ``(weight, weight_quantizer)`` pairs for weight-only calibration.
124-
125-
The default implementation iterates over all weights returned by
126-
:func:`~modelopt.torch.quantization.utils.weight_attr_names`. Subclasses that
127-
store weights under non-standard attribute names (e.g.
128-
``_QuantTEGroupedLinear`` uses ``weight0``, ``weight1``, …) should
129-
override this method.
130-
"""
123+
"""Yield ``(weight, weight_quantizer)`` pairs for weight-only calibration."""
131124
from modelopt.torch.quantization.utils import quantizer_attr_names, weight_attr_names
132125

133126
for weight_name in weight_attr_names(self):

modelopt/torch/quantization/plugins/transformer_engine.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -152,12 +152,7 @@ def modelopt_post_restore(self, prefix: str = ""):
152152
delattr(self, "weight")
153153

154154
def iter_weights_for_calibration(self):
155-
"""Yield ``(weight_i, weight_quantizer)`` for each of the ``num_gemms`` grouped weights.
156-
157-
Override is needed because ``self.weight`` is removed in ``_setup``, so the
158-
base-class implementation (which relies on ``weight_attr_names``) would find
159-
no weights. Here we iterate over ``weight0``, ``weight1``, … directly.
160-
"""
155+
"""Yield ``(weight_i, weight_quantizer)`` for each of the ``num_gemms`` grouped weights."""
161156
for i in range(self.num_gemms):
162157
weight_i = getattr(self, f"weight{i}", None)
163158
if weight_i is not None:

0 commit comments

Comments
 (0)