3939"""GPTQ helper and Hessian utilities for calibration."""
4040
4141import math
42+ from contextlib import contextmanager
4243
4344import torch
4445
4748from modelopt .torch .utils .perf import get_used_gpu_mem_fraction
4849
4950
50- def load_vector_lut_codebook (quantizer ):
51- """Load vector LUT codebook and quantizer params from a weight_quantizer.
52-
53- Returns:
54- Tuple of (codebook, quant_block_size, scale_type).
55- """
56- from luts import encode as luts_encode
57-
58- extra_args = quantizer .backend_extra_args
59- encode_format = quantizer .num_bits
60- encode_path = extra_args .get ("encode_path" , "" )
61- if encode_path and not encode_path .endswith ("/" ):
62- encode_path += "/"
63-
64- if "sorted" in encode_format :
65- cb = torch .load (encode_path + encode_format + ".pt" , map_location = "cpu" )
66- codebook = cb ["sorted_values" ].cuda ().float ()
67- else :
68- codebook , _ = luts_encode (encode_format , path = encode_path , norm = False , cuda = True )
69- codebook = codebook .float ()
70-
71- return codebook , extra_args .get ("block_sizes" ), extra_args .get ("scale_type" )
72-
73-
7451def update_hessian (input , hessian , n_samples ):
7552 """Update hessian matrix with new input samples using incremental formula.
7653
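For reference, a minimal sketch of the incremental formula the docstring above refers to, following the standard GPTQ running average; the function body falls outside this hunk, so the exact form (and the name ``update_hessian_sketch``) is an assumption:

    import torch

    def update_hessian_sketch(input, hessian, n_samples):
        # Flatten batch/sequence dims so each row is one calibration sample.
        x = input.reshape(-1, input.shape[-1]).float()
        m = x.shape[0]
        # Running average of 2 * X^T X over all samples seen so far:
        #   H <- H * n / (n + m) + (2 / (n + m)) * X^T X
        hessian *= n_samples / (n_samples + m)
        hessian += (2.0 / (n_samples + m)) * (x.t() @ x)
        return hessian, n_samples + m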
@@ -158,21 +135,33 @@ def free(self):
         self.weight = None
         self.h_inv = None
 
+    def is_vector_lut(self) -> bool:
+        """Check if this module's weight quantizer is configured for vector LUT quantization."""
+        extra_args = getattr(self.module.weight_quantizer, "backend_extra_args", None)
+        return bool(extra_args and extra_args.get("lut_type") == "vector_lut")
+
+    @contextmanager
+    def weight_slice_scales(self, scales):
+        """Temporarily replace _psx_scales with per-slice scales, then restore."""
+        quantizer = self.module.weight_quantizer
+        original = quantizer._psx_scales
+        quantizer._psx_scales = scales
+        try:
+            yield
+        finally:
+            quantizer._psx_scales = original
+
     def update_weights(self, block_size, perc_damp):
         """Run GPTQ blockwise weight update on this module.
 
         Populates ``self.weight`` and ``self.h_inv``, runs the blockwise update,
         logs MSE, and writes the result back to the module.
         """
-        backend_extra_args = getattr(self.module.weight_quantizer, "backend_extra_args", None)
-        is_vector_lut = bool(
-            backend_extra_args and backend_extra_args.get("lut_type") == "vector_lut"
-        )
         hessian = self.hessian.to(self.module.weight.device)
         self.weight = self.module.weight.data.float().clone()
         self._prepare_hessian_inverse(hessian, perc_damp)
 
-        if is_vector_lut:
+        if self.is_vector_lut():
             self._blockwise_vector_update(block_size)
         else:
             self._blockwise_update(block_size)
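``_prepare_hessian_inverse`` is called above but its body is outside this diff. A sketch of the standard GPTQ preparation such a method typically performs (pin dead columns, dampen the diagonal, keep the upper Cholesky factor of the inverse); treat the details as assumptions, not the file's actual code:

    import torch

    def prepare_hessian_inverse_sketch(weight, hessian, perc_damp):
        # Columns never activated during calibration have zero curvature;
        # pin their diagonal so the matrix stays invertible, zero the weights.
        dead = torch.diag(hessian) == 0
        hessian[dead, dead] = 1
        weight[:, dead] = 0
        # Dampen: add perc_damp * mean(diag(H)) to the diagonal.
        idx = torch.arange(hessian.shape[0], device=hessian.device)
        hessian[idx, idx] += perc_damp * torch.mean(torch.diag(hessian))
        # The blockwise update consumes the upper Cholesky factor of H^-1.
        h_inv = torch.cholesky_inverse(torch.linalg.cholesky(hessian))
        return torch.linalg.cholesky(h_inv, upper=True)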
@@ -256,36 +245,28 @@ def _blockwise_update(self, block_size):
         )
 
     def _blockwise_vector_update(self, block_size):
-        """GPTQ blockwise update for vector LUT quantizers.
+        """GPTQ blockwise update for vector quantizers.
 
-        Pre-computes scales once, then runs the standard GPTQ 3-loop
-        with per-vector-group static quantization via clip_vector_prescaled.
+        A single ``quantizer(weight)`` call computes and caches per-block
+        scales (``_psx_scales``) on the quantizer via the backend's
+        ``static_scales`` path. The GPTQ loop then slices per-vector-group
+        scales from ``_psx_scales`` for each sub-vector quantization call.
         """
         import torch.nn.functional as F
-        from luts import clip_vector_prescaled, clip_vector_scalesign_fast
 
-        codebook, quant_block_size, scale_type = load_vector_lut_codebook(
-            self.module.weight_quantizer
+        quantizer = self.module.weight_quantizer
+        assert quantizer.backend_extra_args.get("static_scales", False), (
+            "GPTQ vector update requires static_scales=True in backend_extra_args."
         )
-
-        # Get vector size from codebook
-        vector_size = codebook.shape[1]
+        vector_size = quantizer.backend_extra_args["vector_size"]
+        quant_block_size = quantizer.backend_extra_args["block_sizes"]
 
         assert self.weight is not None and self.h_inv is not None
        num_cols = self.weight.shape[1]
         assert block_size % quant_block_size == 0
 
-        # Pre-compute scales once outside the GPTQ loop
-        _, scales = clip_vector_scalesign_fast(
-            self.weight,
-            codebook,
-            quant_block_size,
-            scale_type,
-            scale_algo="max",
-            sign_scale=True,
-            return_scales=True,
-        )
-        scales_2d = scales.reshape(self.weight.shape[0], -1)
+        # Compute and cache _psx_scales on the quantizer via the backend.
+        quantizer(self.weight)
 
         w = self.weight.clone()
         h_inv = self.h_inv
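The ``static_scales`` contract relied on here is internal to the quantizer backend and not visible in this diff. Loosely, the assumed behavior is along these lines (``compute_block_scales`` and ``quantize_with_scales`` are illustrative placeholders, not real APIs):

    # Illustrative only: the first full-weight call computes and caches
    # per-block scales; later calls quantize with whatever _psx_scales
    # currently holds, which is what weight_slice_scales() swaps in.
    def forward(self, weight):
        if self.backend_extra_args.get("static_scales") and self._psx_scales is None:
            block_sizes = self.backend_extra_args["block_sizes"]
            self._psx_scales = compute_block_scales(weight, block_sizes)
        return quantize_with_scales(weight, self._psx_scales)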
@@ -296,12 +277,14 @@ def _blockwise_vector_update(self, block_size):
 
         for j in range(blk_start, blk_end, vector_size):
             d = min(vector_size, blk_end - j)
-            s = scales_2d[:, j // quant_block_size].contiguous()
+            s = quantizer._psx_scales[:, j // quant_block_size].contiguous()
 
             sub = w[:, j : j + d].contiguous()
             if d < vector_size:
                 sub = F.pad(sub, (0, vector_size - d))
-            q_sub = clip_vector_prescaled(sub, codebook, s)
+
+            with self.weight_slice_scales(s):
+                q_sub = quantizer(sub)
 
             for k in range(d):
                 col = j + k
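The diff cuts off inside the innermost loop. For context, a sketch of the standard GPTQ per-column error propagation that typically follows ``col = j + k``; the file's actual tail is not shown, so this is an assumption:

    import torch

    def gptq_column_update(w, q_sub, h_inv, j, d, blk_end):
        # Commit each quantized column, then push its weighted error onto
        # the not-yet-quantized columns via the inverse-Hessian row.
        for k in range(d):
            col = j + k
            err = (w[:, col] - q_sub[:, k]) / h_inv[col, col]
            w[:, col + 1 : blk_end] -= err.unsqueeze(1) * h_inv[col, col + 1 : blk_end]
            w[:, col] = q_sub[:, k]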