@@ -74,6 +74,42 @@ def update_hessian(input, hessian, n_samples):
7474 return hessian , n_samples
7575
7676
def compute_hessian_inverse(hessian, weight, perc_damp):
    """Compute the damped upper-Cholesky factor of the inverse Hessian.

    Columns belonging to dead neurons (all-zero columns of ``weight``) are
    zeroed out of the Hessian and given a unit diagonal before inversion,
    matching the FP-Quant reference:
    https://github.com/IST-DASLab/FP-Quant/blob/d2e3092f968262c4de5fb050e1aef568a280dadd/src/quantization/gptq.py#L200

    Args:
        hessian: Hessian matrix ``[in_features, in_features]``.
        weight: Weight matrix ``[out_features, in_features]`` used only to
            detect dead-neuron columns.
        perc_damp: Fraction of the mean Hessian diagonal added as damping.

    Returns:
        Upper-triangular Cholesky factor of the damped inverse Hessian,
        shape ``[in_features, in_features]``. If the damped Hessian is not
        positive definite, the identity matrix is returned instead.
    """
    h = hessian.clone()

    # Neutralize dead columns: zero their rows/cols, put 1 on the diagonal
    # so the matrix stays invertible and those columns decouple.
    dead = weight.eq(0).all(dim=0)
    h[dead, :] = 0
    h[:, dead] = 0
    h[dead, dead] = 1

    # Standard GPTQ damping: a small multiple of the average diagonal.
    damp = perc_damp * torch.mean(torch.diag(h))
    h.diagonal().add_(damp)

    try:
        h = torch.cholesky_inverse(torch.linalg.cholesky(h))
        return torch.linalg.cholesky(h, upper=True)
    except (RuntimeError, torch.linalg.LinAlgError):
        print_rank_0("Warning: Hessian is not positive definite, using identity matrix")
        return torch.eye(h.shape[0], device=h.device, dtype=h.dtype)
111+
112+
77113class GPTQHelper :
78114 """Encapsulates per-module GPTQ state and operations.
79115
@@ -154,38 +190,14 @@ def update_weights(self, block_size, perc_damp):
154190 # ------------------------------------------------------------------
155191
def _prepare_hessian_inverse(self, hessian, perc_damp):
    """Populate ``self.h_inv`` with the damped upper-Cholesky inverse Hessian.

    Thin wrapper around :func:`compute_hessian_inverse`, which also handles
    dead-neuron columns and the identity fallback for non-PD Hessians.
    """
    assert self.weight is not None, "_prepare_hessian_inverse called before update_weights()"
    self.h_inv = compute_hessian_inverse(hessian, self.weight, perc_damp)
181196
182197 def _blockwise_update (self , block_size ):
183198 """Column-wise GPTQ update using full-matrix QDQ.
184199
185- For each column, quantizes the full weight matrix via the quantizer and
186- extracts the quantized column. This is the standard GPTQ approach.
187-
188- Reads/writes ``self.weight`` and ``self.h_inv`` in-place.
200+ Delegates to :func:`gptq_blockwise_update` with the module's weight quantizer.
189201 """
190202 assert self .weight is not None and self .h_inv is not None , (
191203 "_blockwise_update called before _prepare_hessian_inverse()"
@@ -199,28 +211,7 @@ def _blockwise_update(self, block_size):
199211 f"GPTQ block_size ({ block_size } ) must be divisible by the quantizer"
200212 f" group_size ({ group_size } )"
201213 )
202- num_cols = self .weight .shape [1 ]
203-
204- for block_start in range (0 , num_cols , block_size ):
205- block_end = min (block_start + block_size , num_cols )
206- n_cols_blk = block_end - block_start
207- h_inv_cho_blk = self .h_inv [block_start :block_end , block_start :block_end ]
208-
209- wblk = self .weight .clone ()
210- errs = torch .zeros_like (wblk [:, block_start :block_end ])
211-
212- for i in range (n_cols_blk ):
213- w_ci = wblk [:, block_start + i ]
214- d = h_inv_cho_blk [i , i ]
215- qdq = quantizer (wblk )
216- self .weight [:, block_start + i ] = qdq [:, block_start + i ]
217- err = (w_ci - qdq [:, block_start + i ]) / d
218- wblk [:, block_start + i : block_end ].addr_ (err , h_inv_cho_blk [i , i :], alpha = - 1 )
219- errs [:, i ] = err
220-
221- self .weight [:, block_end :].addmm_ (
222- errs , self .h_inv [block_start :block_end , block_end :], alpha = - 1
223- )
214+ gptq_blockwise_update (self .weight , self .h_inv , block_size , quantizer )
224215
225216 def _print_mse_error (self , hessian ):
226217 """Log Hessian-weighted relative MSE between ``self.weight`` and original weights."""
@@ -231,6 +222,115 @@ def _print_mse_error(self, hessian):
231222 print_rank_0 (f"[{ self .name } ] Relative MSE error: { mse .item ():.2e} { suffix } " )
232223
233224
def gptq_blockwise_update(weight, h_inv, block_size, quantize_fn):
    """Column-wise GPTQ update using full-matrix fake quantization.

    For each column, the full weight matrix is fake-quantized via
    ``quantize_fn`` and the quantized column extracted. The quantization
    error is propagated to the remaining columns of the current block and,
    once the block is finished, to all later columns through the inverse
    Hessian.

    Args:
        weight: Weight tensor ``[out_features, in_features]``, overwritten
            **in-place** with fake-quantized values.
        h_inv: Upper-triangular Cholesky factor of the damped inverse
            Hessian ``[in_features, in_features]``.
        block_size: Number of columns handled per GPTQ block.
        quantize_fn: Callable ``(weight) -> qdq_weight`` that fake-quantizes
            the full weight matrix.
    """
    total_cols = weight.shape[1]

    for start in range(0, total_cols, block_size):
        stop = min(start + block_size, total_cols)
        chol_blk = h_inv[start:stop, start:stop]

        # Work on a scratch copy so error feedback inside the block does not
        # disturb the columns already committed to ``weight``.
        work = weight.clone()
        block_errs = torch.zeros_like(weight[:, start:stop])

        for col in range(stop - start):
            j = start + col
            pre_quant = work[:, j]
            qdq_full = quantize_fn(work)
            weight[:, j] = qdq_full[:, j]
            # Scaled quantization error for this column (view read before the
            # in-place rank-1 update below mutates it).
            col_err = (pre_quant - qdq_full[:, j]) / chol_blk[col, col]
            work[:, j:stop].addr_(col_err, chol_blk[col, col:], alpha=-1)
            block_errs[:, col] = col_err

        # Push the accumulated block error onto every subsequent column.
        weight[:, stop:].addmm_(block_errs, h_inv[start:stop, stop:], alpha=-1)
261+
262+
def gptq_blockwise_update_fused_scalar(weight, scales_2d, h_inv, block_size, quant_block_size):
    """Fused GPTQ blockwise update for NVFP4 scalar quantization.

    Each GPTQ block is handled by a single fused Triton kernel launch that
    combines quantization with per-column error propagation, replacing the
    Python-level column loop of :func:`gptq_blockwise_update`.

    Args:
        weight: Weight tensor ``[out_features, in_features]``, overwritten
            **in-place** with fake-quantized values.
        scales_2d: Pre-computed per-block scales ``[out_features, n_scale_blocks]``.
        h_inv: Upper-triangular Cholesky factor of the damped inverse
            Hessian ``[in_features, in_features]``.
        block_size: Number of columns handled per GPTQ block.
        quant_block_size: Number of elements sharing one quantization scale factor.
    """
    from modelopt.torch.quantization.triton.gptq_fused_kernel import gptq_fused_block_scalar

    total_cols = weight.shape[1]
    for start in range(0, total_cols, block_size):
        stop = min(start + block_size, total_cols)
        qweight, block_err = gptq_fused_block_scalar(
            weight[:, start:stop].clone().contiguous(),
            scales_2d,
            h_inv[start:stop, start:stop].contiguous(),
            quant_block_size,
            start,
        )
        weight[:, start:stop] = qweight
        # Propagate the block's accumulated error to all later columns.
        if stop < total_cols:
            weight[:, stop:].addmm_(block_err, h_inv[start:stop, stop:], alpha=-1)
294+
295+
class FusedScalarGPTQHelper(GPTQHelper):
    """GPTQHelper using the fused Triton kernel for NVFP4 scalar quantization.

    Overrides :meth:`_blockwise_update` to extract pre-computed scales from the
    ``NVFP4StaticQuantizer`` and delegate to :func:`gptq_blockwise_update_fused_scalar`.
    """

    def _blockwise_update(self, block_size):
        """Fused GPTQ using Triton kernel for NVFP4 scalar quantization.

        Args:
            block_size: Number of columns to process per GPTQ block; must be
                divisible by the quantizer's group size.

        Raises:
            ValueError: If the quantizer exposes no usable group size, or if
                ``block_size`` is not divisible by it.
        """
        assert self.weight is not None and self.h_inv is not None, (
            "_blockwise_update called before _prepare_hessian_inverse()"
        )
        from modelopt.torch.quantization.triton.fp4_kernel import compute_fp4_scales

        quantizer = self.module.weight_quantizer
        block_sizes = getattr(quantizer, "block_sizes", None)
        quant_block_size = None
        if block_sizes is not None:
            quant_block_size = block_sizes.get(-1) or block_sizes.get(1)

        # Fail fast with an actionable message. Previously a missing group
        # size slipped past the divisibility guard and crashed below with an
        # opaque ``TypeError`` on ``num_cols // None``.
        if quant_block_size is None:
            raise ValueError(
                "FusedScalarGPTQHelper requires the weight quantizer to define"
                " a group size via block_sizes[-1] or block_sizes[1]"
            )
        if block_size % quant_block_size != 0:
            raise ValueError(
                f"GPTQ block_size ({block_size}) must be divisible by the quantizer"
                f" group_size ({quant_block_size})"
            )

        out_features, num_cols = self.weight.shape
        n_blocks = num_cols // quant_block_size

        # Pre-compute scales from the calibrated amax (frozen during GPTQ).
        # NOTE(review): assumes num_cols is an exact multiple of
        # quant_block_size and quantizer.amax holds out_features * n_blocks
        # elements — confirm against the quantizer's calibration contract.
        amax = quantizer.amax.reshape(out_features, n_blocks)
        scales_2d = compute_fp4_scales(amax, quantizer.global_amax, quantize_block_scales=True)

        gptq_blockwise_update_fused_scalar(
            self.weight, scales_2d, self.h_inv, block_size, quant_block_size
        )
332+
333+
234334_GPTQ_HELPER_REGISTRY : dict [str , type [GPTQHelper ]] = {}
235335
236336
@@ -242,3 +342,7 @@ def register_gptq_helper(backend: str, factory: type[GPTQHelper]) -> None:
242342 construct ``factory`` instead of the default ``GPTQHelper``.
243343 """
244344 _GPTQ_HELPER_REGISTRY [backend ] = factory
345+
# Built-in registrations
# Routes backend "fused_gptq_nvfp4" to the Triton-fused helper.
register_gptq_helper("fused_gptq_nvfp4", FusedScalarGPTQHelper)
0 commit comments