|
49 | 49 | reduce_amax, |
50 | 50 | weight_attr_names, |
51 | 51 | ) |
52 | | -from .utils.calib_utils import GPTQHelper |
| 52 | +from .utils.calib_utils import _GPTQ_HELPER_REGISTRY, GPTQHelper |
53 | 53 |
|
54 | 54 | __all__ = [ |
55 | 55 | "awq", |
@@ -1589,6 +1589,21 @@ def sequential_calibrate( |
1589 | 1589 |
|
1590 | 1590 | def _layer_forward_loop(m, _inputs=layer_inputs): |
1591 | 1591 | for args, kwargs_input in _inputs: |
| 1592 | + # Reset past_key_values to prevent the KV cache from |
| 1593 | + # accumulating across multiple forward replays (e.g. |
| 1594 | + # max_calibrate then Hessian collection in GPTQ). |
| 1595 | + # The layer doesn't need stale KV data — each replay |
| 1596 | + # should start with a fresh cache. |
| 1597 | + if ( |
| 1598 | + "past_key_values" in kwargs_input |
| 1599 | + and kwargs_input["past_key_values"] is not None |
| 1600 | + ): |
| 1601 | + kwargs_input = dict(kwargs_input) |
| 1602 | + cache = kwargs_input["past_key_values"] |
| 1603 | + if hasattr(cache, "reset"): |
| 1604 | + cache.reset() |
| 1605 | + else: |
| 1606 | + kwargs_input["past_key_values"] = None |
1592 | 1607 | m(*args, **kwargs_input) |
1593 | 1608 |
|
1594 | 1609 | calib_func(layer, _layer_forward_loop, **calib_kwargs) |
@@ -1648,7 +1663,15 @@ def gptq( |
1648 | 1663 | print_rank_0("No quantized linear layers found, skipping GPTQ") |
1649 | 1664 | return |
1650 | 1665 |
|
1651 | | - gptq_handles = {name: GPTQHelper(m, name, offload_to_cpu=True) for name, m in quantized_layers} |
| 1666 | + def _make_gptq_handle(name, m): |
| 1667 | + backend = getattr(m.weight_quantizer, "backend", None) |
| 1668 | + if backend is None: |
| 1669 | + cls = GPTQHelper |
| 1670 | + else: |
| 1671 | + cls = _GPTQ_HELPER_REGISTRY.get(backend, GPTQHelper) |
| 1672 | + return cls(m, name, offload_to_cpu=True) |
| 1673 | + |
| 1674 | + gptq_handles = {name: _make_gptq_handle(name, m) for name, m in quantized_layers} |
1652 | 1675 | for handle in gptq_handles.values(): |
1653 | 1676 | handle.setup() |
1654 | 1677 |
|
|
0 commit comments