Skip to content

Commit 3d5646d

Browse files
authored
Merge branch 'main' into jingyux/diffusion-skip-softmax
2 parents fa6d2ad + dc7ad66 commit 3d5646d

3 files changed

Lines changed: 39 additions & 4 deletions

File tree

examples/llm_ptq/hf_ptq.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,7 @@ def make_calib_dataloader(
283283
include_labels = (
284284
args.auto_quantize_bits is not None and args.auto_quantize_method == "gradient"
285285
)
286+
286287
calib_dataloader = get_dataset_dataloader(
287288
dataset_name=args.dataset,
288289
tokenizer=tokenizer,

modelopt/torch/quantization/model_calib.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949
reduce_amax,
5050
weight_attr_names,
5151
)
52-
from .utils.calib_utils import GPTQHelper
52+
from .utils.calib_utils import _GPTQ_HELPER_REGISTRY, GPTQHelper
5353

5454
__all__ = [
5555
"awq",
@@ -1589,6 +1589,21 @@ def sequential_calibrate(
15891589

15901590
def _layer_forward_loop(m, _inputs=layer_inputs):
    """Replay the captured calibration inputs through ``m``.

    Before each forward call, any ``past_key_values`` entry is cleared so
    repeated replays of the same inputs (e.g. max_calibrate followed by
    GPTQ Hessian collection) never observe a KV cache left over from a
    previous pass.
    """
    for fwd_args, fwd_kwargs in _inputs:
        kv_cache = fwd_kwargs.get("past_key_values")
        if kv_cache is not None:
            # Shallow-copy so the stored kwargs dict itself is never mutated.
            fwd_kwargs = dict(fwd_kwargs)
            if hasattr(kv_cache, "reset"):
                # Cache object supports in-place reset (e.g. HF Cache API).
                kv_cache.reset()
            else:
                # No reset hook: drop the stale cache for this call instead.
                fwd_kwargs["past_key_values"] = None
        m(*fwd_args, **fwd_kwargs)
15931608

15941609
calib_func(layer, _layer_forward_loop, **calib_kwargs)
@@ -1648,7 +1663,15 @@ def gptq(
16481663
print_rank_0("No quantized linear layers found, skipping GPTQ")
16491664
return
16501665

1651-
gptq_handles = {name: GPTQHelper(m, name, offload_to_cpu=True) for name, m in quantized_layers}
1666+
def _make_gptq_handle(name, m):
    """Build the GPTQ helper matching this module's weight-quantizer backend.

    Modules whose ``weight_quantizer`` declares a registered ``backend``
    get the helper class from ``_GPTQ_HELPER_REGISTRY``; everything else
    (no backend attribute, or an unregistered one) falls back to the
    default :class:`GPTQHelper`.
    """
    backend = getattr(m.weight_quantizer, "backend", None)
    helper_cls = (
        GPTQHelper
        if backend is None
        else _GPTQ_HELPER_REGISTRY.get(backend, GPTQHelper)
    )
    return helper_cls(m, name, offload_to_cpu=True)
1673+
1674+
gptq_handles = {name: _make_gptq_handle(name, m) for name, m in quantized_layers}
16521675
for handle in gptq_handles.values():
16531676
handle.setup()
16541677

modelopt/torch/quantization/utils/calib_utils.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -143,9 +143,7 @@ def update_weights(self, block_size, perc_damp):
143143
hessian = self.hessian.to(self.module.weight.device)
144144
self.weight = self.module.weight.data.float().clone()
145145
self._prepare_hessian_inverse(hessian, perc_damp)
146-
147146
self._blockwise_update(block_size)
148-
149147
self._print_mse_error(hessian)
150148
self.module.weight.data = self.weight.reshape(self.module.weight.shape).to(
151149
self.module.weight.data.dtype
@@ -231,3 +229,16 @@ def _print_mse_error(self, hessian):
231229
mse = (delta).mm(hessian).mul(delta).mean() / (w_orig.mm(hessian).mul(w_orig).mean() + 1e-6)
232230
suffix = f", n_hessian_samples: {self.n_samples}" if self.n_samples else ""
233231
print_rank_0(f"[{self.name}] Relative MSE error: {mse.item():.2e}{suffix}")
232+
233+
234+
_GPTQ_HELPER_REGISTRY: dict[str, type[GPTQHelper]] = {}
235+
236+
237+
def register_gptq_helper(backend: str, factory: type[GPTQHelper]) -> None:
    """Map a quantizer backend name to a :class:`GPTQHelper` subclass.

    :func:`modelopt.torch.quantization.model_calib.gptq` consults this
    mapping: any module whose ``weight_quantizer.backend`` equals
    ``backend`` is handled by ``factory`` rather than the default
    ``GPTQHelper``.
    """
    _GPTQ_HELPER_REGISTRY[backend] = factory

0 commit comments

Comments
 (0)