3939"""GPTQ helper and Hessian utilities for calibration."""
4040
4141import math
42- from contextlib import contextmanager
4342
4443import torch
4544
@@ -135,22 +134,6 @@ def free(self):
135134 self .weight = None
136135 self .h_inv = None
137136
def is_vector_lut(self) -> bool:
    """Return True when the weight quantizer targets vector LUT quantization.

    Reads ``backend_extra_args`` off the module's weight quantizer; a missing
    or empty mapping counts as "not vector LUT".
    """
    quantizer = self.module.weight_quantizer
    args = getattr(quantizer, "backend_extra_args", None)
    if not args:
        return False
    return args.get("lut_type") == "vector_lut"
142-
@contextmanager
def weight_slice_scales(self, scales):
    """Context manager that installs *scales* as the quantizer's ``_psx_scales``.

    The previous value is saved before entry and restored on exit, even when
    the managed body raises.
    """
    wq = self.module.weight_quantizer
    saved = wq._psx_scales
    wq._psx_scales = scales
    try:
        yield
    finally:
        wq._psx_scales = saved
153-
154137 def update_weights (self , block_size , perc_damp ):
155138 """Run GPTQ blockwise weight update on this module.
156139
@@ -160,12 +143,7 @@ def update_weights(self, block_size, perc_damp):
160143 hessian = self .hessian .to (self .module .weight .device )
161144 self .weight = self .module .weight .data .float ().clone ()
162145 self ._prepare_hessian_inverse (hessian , perc_damp )
163-
164- if self .is_vector_lut ():
165- self ._blockwise_vector_update (block_size )
166- else :
167- self ._blockwise_update (block_size )
168-
146+ self ._blockwise_update (block_size )
169147 self ._print_mse_error (hessian )
170148 self .module .weight .data = self .weight .reshape (self .module .weight .shape ).to (
171149 self .module .weight .data .dtype
@@ -244,62 +222,23 @@ def _blockwise_update(self, block_size):
244222 errs , self .h_inv [block_start :block_end , block_end :], alpha = - 1
245223 )
246224
def _blockwise_vector_update(self, block_size):
    """GPTQ blockwise update for vector quantizers.

    A single ``quantizer(weight)`` call computes and caches per-block
    scales (``_psx_scales``) on the quantizer via the backend's
    ``static_scales`` path. The GPTQ loop then slices per-vector-group
    scales from ``_psx_scales`` for each sub-vector quantization call.
    """
    import torch.nn.functional as F

    quantizer = self.module.weight_quantizer
    # static_scales is required so _psx_scales is populated once up front and
    # stays valid while sub-vectors are quantized slice by slice below.
    assert quantizer.backend_extra_args.get("static_scales", False), (
        "GPTQ vector update requires static_scales=True in backend_extra_args."
    )
    vector_size = quantizer.backend_extra_args["vector_size"]
    quant_block_size = quantizer.backend_extra_args["block_sizes"]

    assert self.weight is not None and self.h_inv is not None
    num_cols = self.weight.shape[1]
    # GPTQ blocks must align with the quantizer's scale-block boundaries so a
    # GPTQ block never straddles two different scale groups.
    assert block_size % quant_block_size == 0

    # Compute and cache _psx_scales on the quantizer via the backend.
    quantizer(self.weight)

    # Work on a copy: `w` accumulates error compensation, while quantized
    # columns are written back into self.weight.
    w = self.weight.clone()
    h_inv = self.h_inv

    for blk_start in range(0, num_cols, block_size):
        blk_end = min(blk_start + block_size, num_cols)
        # Scaled per-column errors of this block, used after the block is done
        # to propagate the correction to all columns beyond blk_end.
        errs = torch.zeros_like(w[:, blk_start:blk_end])

        for j in range(blk_start, blk_end, vector_size):
            d = min(vector_size, blk_end - j)
            # Scales of the vector group that contains column j.
            s = quantizer._psx_scales[:, j // quant_block_size].contiguous()

            sub = w[:, j : j + d].contiguous()
            if d < vector_size:
                # Zero-pad a ragged tail group up to a full vector width.
                sub = F.pad(sub, (0, vector_size - d))

            # Quantize the sub-vector with the temporarily installed scales.
            with self.weight_slice_scales(s):
                q_sub = quantizer(sub)

            for k in range(d):
                col = j + k
                self.weight[:, col] = q_sub[:, k]
                # GPTQ error term, scaled by the diagonal entry of H^-1.
                err = (w[:, col] - q_sub[:, k]) / h_inv[col, col]
                errs[:, col - blk_start] = err
                # In-place rank-1 update: compensate the remaining columns of
                # the current block for the error just introduced at `col`.
                w[:, col:blk_end].addr_(err, h_inv[col, col:blk_end], alpha=-1)

        if blk_end < num_cols:
            # Propagate the accumulated block error to all later columns.
            w[:, blk_end:] -= errs @ h_inv[blk_start:blk_end, blk_end:]
298-
def _print_mse_error(self, hessian):
    """Log Hessian-weighted relative MSE between ``self.weight`` and original weights."""
    w_orig = self.module.weight.float()
    delta = self.weight - w_orig
    # Hessian-weighted energy of the quantization error, normalized by the
    # weighted energy of the original weights; the eps avoids divide-by-zero.
    err_energy = delta.mm(hessian).mul(delta).mean()
    ref_energy = w_orig.mm(hessian).mul(w_orig).mean() + 1e-6
    mse = err_energy / ref_energy
    suffix = f", n_hessian_samples: {self.n_samples}" if self.n_samples else ""
    print_rank_0(f"[{self.name}] Relative MSE error: {mse.item():.2e}{suffix}")
232+
233+
# Maps a quantizer backend name to the GPTQHelper subclass handling it.
_GPTQ_HELPER_REGISTRY: dict[str, type[GPTQHelper]] = {}


def register_gptq_helper(backend: str, factory: type[GPTQHelper]) -> None:
    """Associate a quantizer backend with a :class:`GPTQHelper` subclass.

    When :func:`modelopt.torch.quantization.model_calib.gptq` encounters a
    module whose ``weight_quantizer.backend`` matches ``backend``, it will
    construct ``factory`` instead of the default ``GPTQHelper``. Registering
    the same ``backend`` again overwrites the earlier entry.
    """
    _GPTQ_HELPER_REGISTRY[backend] = factory