@@ -126,10 +126,11 @@ class GPTQHelper:
126126
127127 CACHE_NAME = "_forward_no_gptq_hessian"
128128
129- def __init__ (self , module , name , offload_to_cpu = False ):
129+ def __init__ (self , module , name , offload_to_cpu = False , fused = False ):
130130 """Initialize GPTQHelper with module state and Hessian storage."""
131131 self .module = module
132132 self .name = name
133+ self .fused = fused
133134 in_features = module .weight .shape [- 1 ]
134135 device = module .weight .device
135136 if device .type == "meta" or (offload_to_cpu and get_used_gpu_mem_fraction (device ) > 0.65 ):
@@ -195,23 +196,35 @@ def _prepare_hessian_inverse(self, hessian, perc_damp):
195196 self .h_inv = compute_hessian_inverse (hessian , self .weight , perc_damp )
196197
197198 def _blockwise_update (self , block_size ):
198- """Column-wise GPTQ update using full-matrix QDQ .
199+ """Column-wise GPTQ update.
199200
200- Delegates to :func:`gptq_blockwise_update` with the module's weight quantizer.
201+ When ``self.fused`` is True and the weight quantizer is an
202+ ``NVFP4StaticQuantizer``, uses :func:`gptq_blockwise_update_fused_scalar`
203+ (a fused Triton kernel). Otherwise falls back to
204+ :func:`gptq_blockwise_update` (unfused column-by-column loop).
201205 """
202206 assert self .weight is not None and self .h_inv is not None , (
203207 "_blockwise_update called before _prepare_hessian_inverse()"
204208 )
205209 quantizer = self .module .weight_quantizer
206- block_sizes = getattr (quantizer , "block_sizes" , None )
207- if block_sizes is not None :
208- group_size = block_sizes .get (- 1 )
209- if group_size is not None and block_size % group_size != 0 :
210+
211+ if self .fused and getattr (quantizer , "_is_nvfp4_static_quantizer" , False ):
212+ block_sizes = quantizer .block_sizes
213+ quant_block_size = block_sizes .get (- 1 ) or block_sizes .get (1 )
214+ if quant_block_size is not None and block_size % quant_block_size != 0 :
210215 raise ValueError (
211216 f"GPTQ block_size ({ block_size } ) must be divisible by the quantizer"
212- f" group_size ({ group_size } )"
217+ f" group_size ({ quant_block_size } )"
213218 )
214- gptq_blockwise_update (self .weight , self .h_inv , block_size , quantizer )
219+ out_features , num_cols = self .weight .shape
220+ n_blocks = num_cols // quant_block_size
221+ block_amax = quantizer .amax .reshape (out_features , n_blocks ).float ()
222+ global_scale = quantizer .global_amax .float ().item () / (6.0 * 448.0 )
223+ gptq_blockwise_update_fused_scalar (
224+ self .weight , block_amax , global_scale , self .h_inv , block_size , quant_block_size
225+ )
226+ else :
227+ gptq_blockwise_update (self .weight , self .h_inv , block_size , quantizer )
215228
216229 def _print_mse_error (self , hessian ):
217230 """Log Hessian-weighted relative MSE between ``self.weight`` and original weights."""
@@ -260,17 +273,20 @@ def gptq_blockwise_update(weight, h_inv, block_size, quantize_fn):
260273 weight [:, block_end :].addmm_ (errs , h_inv [block_start :block_end , block_end :], alpha = - 1 )
261274
262275
263- def gptq_blockwise_update_fused_scalar (weight , scales_2d , h_inv , block_size , quant_block_size ):
276+ def gptq_blockwise_update_fused_scalar (
277+ weight , block_amax , global_scale , h_inv , block_size , quant_block_size
278+ ):
264279 """Fused GPTQ blockwise update for NVFP4 scalar quantization.
265280
266- Uses a fused Triton kernel that combines quantization and per-column
267- error propagation into one launch per GPTQ block, avoiding the
268- Python-level per-column loop in :func:`gptq_blockwise_update`.
281+ Uses a fused Triton kernel that combines scale computation, quantization,
282+ and per-column error propagation into one launch per GPTQ block, avoiding
283+ the Python-level per-column loop in :func:`gptq_blockwise_update`.
269284
270285 Args:
271286 weight: Weight tensor ``[out_features, in_features]``, modified **in-place**
272287 with fake-quantized values.
273- scales_2d: Pre-computed per-block scales ``[out_features, n_scale_blocks]``.
288+ block_amax: Per-block amax values ``[out_features, n_amax_blocks]``.
289+ global_scale: Pre-computed ``global_amax / (6.0 * 448.0)`` (scalar).
274290 h_inv: Upper-triangular Cholesky factor of the damped inverse Hessian
275291 ``[in_features, in_features]``.
276292 block_size: Number of columns to process per GPTQ block.
@@ -283,7 +299,8 @@ def gptq_blockwise_update_fused_scalar(weight, scales_2d, h_inv, block_size, qua
283299 be = min (bs + block_size , num_cols )
284300 qw , err = gptq_fused_block_scalar (
285301 weight [:, bs :be ].clone ().contiguous (),
286- scales_2d ,
302+ block_amax ,
303+ global_scale ,
287304 h_inv [bs :be , bs :be ].contiguous (),
288305 quant_block_size ,
289306 bs ,
@@ -293,44 +310,6 @@ def gptq_blockwise_update_fused_scalar(weight, scales_2d, h_inv, block_size, qua
293310 weight [:, be :].addmm_ (err , h_inv [bs :be , be :], alpha = - 1 )
294311
295312
296- class FusedScalarGPTQHelper (GPTQHelper ):
297- """GPTQHelper using the fused Triton kernel for NVFP4 scalar quantization.
298-
299- Overrides :meth:`_blockwise_update` to extract pre-computed scales from the
300- ``NVFP4StaticQuantizer`` and delegate to :func:`gptq_blockwise_update_fused_scalar`.
301- """
302-
303- def _blockwise_update (self , block_size ):
304- """Fused GPTQ using Triton kernel for NVFP4 scalar quantization."""
305- assert self .weight is not None and self .h_inv is not None , (
306- "_blockwise_update called before _prepare_hessian_inverse()"
307- )
308- from modelopt .torch .quantization .triton .fp4_kernel import compute_fp4_scales
309-
310- quantizer = self .module .weight_quantizer
311- block_sizes = getattr (quantizer , "block_sizes" , None )
312- quant_block_size = None
313- if block_sizes is not None :
314- quant_block_size = block_sizes .get (- 1 ) or block_sizes .get (1 )
315-
316- if quant_block_size is not None and block_size % quant_block_size != 0 :
317- raise ValueError (
318- f"GPTQ block_size ({ block_size } ) must be divisible by the quantizer"
319- f" group_size ({ quant_block_size } )"
320- )
321-
322- out_features , num_cols = self .weight .shape
323- n_blocks = num_cols // quant_block_size
324-
325- # Pre-compute scales from the calibrated amax (frozen during GPTQ).
326- amax = quantizer .amax .reshape (out_features , n_blocks )
327- scales_2d = compute_fp4_scales (amax , quantizer .global_amax , quantize_block_scales = True )
328-
329- gptq_blockwise_update_fused_scalar (
330- self .weight , scales_2d , self .h_inv , block_size , quant_block_size
331- )
332-
333-
334313_GPTQ_HELPER_REGISTRY : dict [str , type [GPTQHelper ]] = {}
335314
336315
@@ -342,7 +321,3 @@ def register_gptq_helper(backend: str, factory: type[GPTQHelper]) -> None:
342321 construct ``factory`` instead of the default ``GPTQHelper``.
343322 """
344323 _GPTQ_HELPER_REGISTRY [backend ] = factory
345-
346-
347- # Built-in registrations
348- register_gptq_helper ("fused_gptq_nvfp4" , FusedScalarGPTQHelper )