@@ -1545,7 +1545,7 @@ def _print_relative_mse_error(
15451545 delta = q - w
15461546 mse = (delta ).mm (h ).mul (delta ).mean () / (w .mm (h ).mul (w ).mean () + 1e-6 )
15471547 suffix = f", n_hessian_samples: { n_samples } " if n_samples is not None else ""
1548- print (f"[{ module_name } ] Relative MSE error: { mse .item ():.2e} { suffix } " )
1548+ print_rank_0 (f"[{ module_name } ] Relative MSE error: { mse .item ():.2e} { suffix } " )
15491549
15501550
15511551def update_hessian (input , hessian , n_samples ):
@@ -1604,7 +1604,7 @@ def prepare_hessian_inverse(h, weight, percdamp):
16041604 h = torch .cholesky_inverse (torch .linalg .cholesky (h ))
16051605 h_inv = torch .linalg .cholesky (h , upper = True )
16061606 except (RuntimeError , torch .linalg .LinAlgError ):
1607- print ("Warning: Hessian is not positive definite, using identity matrix" )
1607+ print_rank_0 ("Warning: Hessian is not positive definite, using identity matrix" )
16081608 h_inv = torch .eye (h .shape [0 ], device = h .device , dtype = h .dtype )
16091609 return h_inv
16101610
@@ -1706,37 +1706,104 @@ def _column_qdq_tensor(col, col_idx, _s=scalar_scale, _mx=max_bound, _mn=min_bou
17061706 return _column_qdq_tensor , True
17071707
17081708
1709+ def _can_use_fused_gptq (quantizer ) -> bool :
1710+ """Check whether the fused Triton GPTQ kernel can be used for *quantizer*."""
1711+ if not isinstance (quantizer , NVFP4StaticQuantizer ):
1712+ return False
1713+ if not hasattr (quantizer , "_amax" ) or quantizer ._amax is None :
1714+ return False
1715+ from modelopt .torch .quantization .triton import IS_AVAILABLE as _TRITON_OK
1716+
1717+ return _TRITON_OK
1718+
1719+
17091720def blockwise_weight_update (module , h , block_size , percdamp , n_samples = None ):
17101721 """Update module weights using GPTQ-style blockwise quantization.
17111722
1723+ Dispatches to one of three internal paths depending on quantizer type:
1724+
1725+ 1. **Fused Triton** — for :class:`NVFP4StaticQuantizer` when Triton is
1726+ available. Runs the entire column loop in a single GPU kernel per
1727+ block (~130x faster than the unfused path on Blackwell GPUs).
1728+ 2. **Column-QDQ** — for integer quantizers whose scale geometry allows
1729+ single-column fake-quant via :func:`_build_column_qdq`.
1730+ 3. **Full-matrix fallback** — calls the quantizer on the full weight matrix
1731+ each column (slowest, but always correct).
1732+
17121733 Args:
1713- module: Neural network module with weight and weight_quantizer
1714- H : Hessian matrix (d x d)
1715- block_size: Size of blocks to process at once
1716- percdamp: Damping percentage for Hessian diagonal
1717- n_samples: Number of Hessian samples for logging (optional)
1734+ module: Neural network module with `` weight`` and `` weight_quantizer``.
1735+ h : Hessian matrix of shape ``(d, d)``.
1736+ block_size: Number of columns processed per block.
1737+ percdamp: Damping as a fraction of the mean Hessian diagonal.
1738+ n_samples: Number of Hessian samples (used only for logging).
17181739 """
17191740 weight = module .weight .data .float ().clone ()
1720- _ , num_cols = weight .shape
1741+ num_rows , num_cols = weight .shape
17211742
1722- # Preprocess Hessian: handle dead neurons and add damping
17231743 h_inv = prepare_hessian_inverse (h , weight , percdamp )
17241744
1725- # Try to build fast column-wise qdq (avoids quantizing the full matrix per column)
1726- col_qdq_fn , col_qdq_supported = _build_column_qdq (module .weight_quantizer , weight .shape )
1745+ quantizer = module .weight_quantizer
1746+ if _can_use_fused_gptq (quantizer ):
1747+ _blockwise_weight_update_fused (weight , h_inv , quantizer , num_rows , num_cols , block_size )
1748+ else :
1749+ col_qdq_fn , col_qdq_supported = _build_column_qdq (quantizer , weight .shape )
1750+ _blockwise_weight_update_unfused (
1751+ weight , h_inv , quantizer , num_cols , block_size , col_qdq_fn , col_qdq_supported
1752+ )
1753+
1754+ _print_relative_mse_error (weight , module .weight .float (), h , module .name , n_samples )
1755+ module .weight .data = weight .reshape (module .weight .shape ).to (module .weight .data .dtype )
1756+
1757+
1758+ def _blockwise_weight_update_fused (weight , h_inv , quantizer , num_rows , num_cols , block_size ):
1759+ """Fused Triton path for NVFP4: one kernel launch per block."""
1760+ from modelopt .torch .quantization .triton .gptq_fused_kernel import gptq_fused_block
1761+
1762+ group_size = quantizer .block_sizes .get (- 1 , None ) or quantizer .block_sizes .get (1 , None )
1763+ num_groups = math .ceil (num_cols / group_size )
1764+ amax_grouped = quantizer ._amax .float ().reshape (num_rows , num_groups ).contiguous ()
1765+ global_amax = quantizer .global_amax .float ()
17271766
1728- # Process weights in blocks
17291767 for block_start in range (0 , num_cols , block_size ):
17301768 block_end = min (block_start + block_size , num_cols )
1731- n_cols = block_end - block_start
1769+ n_cols_blk = block_end - block_start
1770+
1771+ w_block = weight [:, block_start :block_end ].clone ().contiguous ()
1772+ h_inv_cho_blk = h_inv [block_start :block_end , block_start :block_end ].contiguous ()
1773+
1774+ qw_block , err_block = gptq_fused_block (
1775+ w_block ,
1776+ amax_grouped ,
1777+ global_amax ,
1778+ h_inv_cho_blk ,
1779+ group_size ,
1780+ block_start ,
1781+ n_cols_blk ,
1782+ )
1783+
1784+ weight [:, block_start :block_end ] = qw_block
1785+ if block_end < num_cols :
1786+ weight [:, block_end :].addmm_ (
1787+ err_block [:, :n_cols_blk ],
1788+ h_inv [block_start :block_end , block_end :],
1789+ alpha = - 1 ,
1790+ )
1791+
1792+
1793+ def _blockwise_weight_update_unfused (
1794+ weight , h_inv , quantizer , num_cols , block_size , col_qdq_fn , col_qdq_supported
1795+ ):
1796+ """Column-QDQ or full-matrix fallback for non-NVFP4 quantizers."""
1797+ for block_start in range (0 , num_cols , block_size ):
1798+ block_end = min (block_start + block_size , num_cols )
1799+ n_cols_blk = block_end - block_start
17321800 h_inv_cho_blk = h_inv [block_start :block_end , block_start :block_end ]
17331801
17341802 if col_qdq_supported :
1735- # Fast path: clone only the block columns, quantize only per-column
17361803 wblk = weight [:, block_start :block_end ].clone ()
17371804 errs = torch .zeros_like (wblk )
17381805
1739- for i in range (n_cols ):
1806+ for i in range (n_cols_blk ):
17401807 w_ci = wblk [:, i ]
17411808 d = h_inv_cho_blk [i , i ]
17421809 qdq_col = col_qdq_fn (w_ci , block_start + i )
@@ -1745,27 +1812,20 @@ def blockwise_weight_update(module, h, block_size, percdamp, n_samples=None):
17451812 wblk [:, i :].addr_ (err , h_inv_cho_blk [i , i :], alpha = - 1 )
17461813 errs [:, i ] = err
17471814 else :
1748- # Fallback: original full-matrix quantization path
17491815 wblk = weight .clone ()
17501816 errs = torch .zeros_like (wblk [:, block_start :block_end ])
17511817
1752- for i in range (n_cols ):
1818+ for i in range (n_cols_blk ):
17531819 w_ci = wblk [:, block_start + i ]
17541820 d = h_inv_cho_blk [i , i ]
1755- qdq = module . weight_quantizer (wblk )
1821+ qdq = quantizer (wblk )
17561822 weight [:, block_start + i ] = qdq [:, block_start + i ]
17571823 err = (w_ci - qdq [:, block_start + i ]) / d
17581824 wblk [:, block_start + i : block_end ].addr_ (err , h_inv_cho_blk [i , i :], alpha = - 1 )
17591825 errs [:, i ] = err
17601826
1761- # Propagate errors to remaining weights
17621827 weight [:, block_end :].addmm_ (errs , h_inv [block_start :block_end , block_end :], alpha = - 1 )
17631828
1764- # Print relative mse error
1765- _print_relative_mse_error (weight , module .weight .float (), h , module .name , n_samples )
1766- # Update module weights
1767- module .weight .data = weight .reshape (module .weight .shape ).to (module .weight .data .dtype )
1768-
17691829
17701830def gptq_lite (
17711831 model : nn .Module ,
0 commit comments