@@ -1609,6 +1609,103 @@ def prepare_hessian_inverse(h, weight, percdamp):
16091609 return h_inv
16101610
16111611
def _build_column_qdq(quantizer, weight_shape):
    """Build a fast column-wise quantize-dequantize function for integer quantizers.

    Instead of calling the full TensorQuantizer on the entire weight matrix (which
    quantizes all elements) and extracting one column, this returns a closure that
    quantizes only a single column using the quantizer's pre-computed amax/scales.

    The closure captures the amax as it is when this function is called, so it is
    only valid while the quantizer's amax stays fixed during the weight updates.
    NOTE(review): equivalence with the full TensorQuantizer output (bounds, rounding,
    output dtype) should be confirmed against the quantizer kernels.

    The integer range follows the usual fake-quant convention:
    ``max_bound = 2 ** (num_bits - 1 + unsigned) - 1`` and ``min_bound`` is ``0`` for
    unsigned, ``-max_bound`` for narrow range, and ``-max_bound - 1`` otherwise.

    Falls back (returns ``(None, False)``) whenever the configuration is not a plain
    integer per-tensor / per-output-channel / per-group (last dim) quantizer, or when
    the amax layout does not match ``weight_shape``.

    Args:
        quantizer: The weight TensorQuantizer (already calibrated).
        weight_shape: Shape of the weight tensor (out_features, in_features).

    Returns:
        Tuple of (column_qdq_fn, supported) where:
            - column_qdq_fn(column, col_idx) -> qdq_column (if supported), with the
              same dtype as ``column``; ``None`` otherwise.
            - supported: True if column-wise qdq is available, False to fall back.
    """
    # Unsupported: NVFP4 (two-level FP4 scaling).
    if isinstance(quantizer, NVFP4StaticQuantizer):
        return None, False

    # Unsupported: FP quantization (num_bits is a tuple) or a missing/non-integer bit-width.
    num_bits = getattr(quantizer, "_num_bits", None)
    if not isinstance(num_bits, int):
        return None, False

    # Unsupported: pre_quant_scale (SmoothQuant) or rotation transforms mix columns.
    if getattr(quantizer, "pre_quant_scale", None) is not None:
        return None, False
    if getattr(quantizer, "rotate_is_enabled", False):
        return None, False

    # Need a calibrated amax.
    amax = getattr(quantizer, "_amax", None)
    if amax is None:
        return None, False

    # The closures operate on columns of a 2D (out_features, in_features) weight.
    if len(weight_shape) != 2:
        return None, False
    out_features, in_features = weight_shape

    unsigned = bool(getattr(quantizer, "_unsigned", False))
    narrow_range = bool(getattr(quantizer, "_narrow_range", False))
    max_bound = (2 ** (num_bits - 1 + int(unsigned))) - 1
    if max_bound <= 0:
        # Degenerate bit-width: the scale would be zero and the dequantize step would divide by it.
        return None, False
    if unsigned:
        min_bound = 0
    elif narrow_range:
        min_bound = -max_bound
    else:
        min_bound = -max_bound - 1

    amax = amax.float()

    def _qdq(col, scale):
        # Scale onto the integer grid, round, saturate, then map back to the input dtype.
        quantized = torch.clamp(torch.round(col * scale), min_bound, max_bound)
        return (quantized / scale).to(col.dtype)

    # Determine quantization geometry from block_sizes.
    block_sizes = getattr(quantizer, "block_sizes", None)
    group_size = None
    if block_sizes is not None:
        # Skip dynamic block quantization.
        if block_sizes.get("type", "static") == "dynamic":
            return None, False
        group_size = block_sizes.get(-1, None) or block_sizes.get(len(weight_shape) - 1, None)

    if group_size is not None and group_size > 0:
        # Per-group block quantization along the last dim: one amax per (row, group).
        if in_features % group_size != 0:
            return None, False  # Padding case - fall back

        n_groups = in_features // group_size
        if amax.numel() != out_features * n_groups:
            return None, False  # amax layout does not match the expected grouping

        # Precompute every (row, group) scale once so each call is a plain column lookup.
        group_scales = max_bound / amax.reshape(out_features, n_groups).clamp(min=1e-12)

        def _column_qdq_group(col, col_idx, _s=group_scales, _gs=group_size):
            return _qdq(col, _s[:, col_idx // _gs])

        return _column_qdq_group, True

    # Per-channel (axis != None) or per-tensor (axis == None).
    axis = getattr(quantizer, "axis", None)
    if axis is not None:
        flat_amax = amax.reshape(-1)
        if flat_amax.numel() != 1:
            # Only one-amax-per-row layouts line up element-wise with a column; any other
            # axis (e.g. per input channel) would broadcast against the wrong dimension.
            axes = (axis,) if isinstance(axis, int) else tuple(axis)
            per_row = {a % len(weight_shape) for a in axes} == {0}
            if not per_row or flat_amax.numel() != out_features:
                return None, False

        row_scales = max_bound / flat_amax.clamp(min=1e-12)

        def _column_qdq_channel(col, col_idx, _s=row_scales):
            return _qdq(col, _s)

        return _column_qdq_channel, True

    # Per-tensor: single scalar scale.
    if amax.numel() != 1:
        return None, False
    scalar_scale = max_bound / amax.clamp(min=1e-12).item()

    def _column_qdq_tensor(col, col_idx, _s=scalar_scale):
        return _qdq(col, _s)

    return _column_qdq_tensor, True
1707+
1708+
16121709def blockwise_weight_update (module , h , block_size , percdamp , n_samples = None ):
16131710 """Update module weights using GPTQ-style blockwise quantization.
16141711
@@ -1625,22 +1722,41 @@ def blockwise_weight_update(module, h, block_size, percdamp, n_samples=None):
16251722 # Preprocess Hessian: handle dead neurons and add damping
16261723 h_inv = prepare_hessian_inverse (h , weight , percdamp )
16271724
1725+ # Try to build fast column-wise qdq (avoids quantizing the full matrix per column)
1726+ col_qdq_fn , col_qdq_supported = _build_column_qdq (module .weight_quantizer , weight .shape )
1727+
16281728 # Process weights in blocks
16291729 for block_start in range (0 , num_cols , block_size ):
16301730 block_end = min (block_start + block_size , num_cols )
16311731 n_cols = block_end - block_start
1632- wblk = weight .clone ()
1633- errs = torch .zeros_like (wblk [:, block_start :block_end ])
16341732 h_inv_cho_blk = h_inv [block_start :block_end , block_start :block_end ]
16351733
1636- for i in range (n_cols ):
1637- w_ci = wblk [:, block_start + i ]
1638- d = h_inv_cho_blk [i , i ]
1639- qdq = module .weight_quantizer (wblk )
1640- weight [:, block_start + i ] = qdq [:, block_start + i ]
1641- err = (w_ci - qdq [:, block_start + i ]) / d
1642- wblk [:, block_start + i : block_end ].addr_ (err , h_inv_cho_blk [i , i :], alpha = - 1 )
1643- errs [:, i ] = err
1734+ if col_qdq_supported :
1735+ # Fast path: clone only the block columns, quantize only per-column
1736+ wblk = weight [:, block_start :block_end ].clone ()
1737+ errs = torch .zeros_like (wblk )
1738+
1739+ for i in range (n_cols ):
1740+ w_ci = wblk [:, i ]
1741+ d = h_inv_cho_blk [i , i ]
1742+ qdq_col = col_qdq_fn (w_ci , block_start + i )
1743+ weight [:, block_start + i ] = qdq_col
1744+ err = (w_ci - qdq_col ) / d
1745+ wblk [:, i :].addr_ (err , h_inv_cho_blk [i , i :], alpha = - 1 )
1746+ errs [:, i ] = err
1747+ else :
1748+ # Fallback: original full-matrix quantization path
1749+ wblk = weight .clone ()
1750+ errs = torch .zeros_like (wblk [:, block_start :block_end ])
1751+
1752+ for i in range (n_cols ):
1753+ w_ci = wblk [:, block_start + i ]
1754+ d = h_inv_cho_blk [i , i ]
1755+ qdq = module .weight_quantizer (wblk )
1756+ weight [:, block_start + i ] = qdq [:, block_start + i ]
1757+ err = (w_ci - qdq [:, block_start + i ]) / d
1758+ wblk [:, block_start + i : block_end ].addr_ (err , h_inv_cho_blk [i , i :], alpha = - 1 )
1759+ errs [:, i ] = err
16441760
16451761 # Propagate errors to remaining weights
16461762 weight [:, block_end :].addmm_ (errs , h_inv [block_start :block_end , block_end :], alpha = - 1 )
@@ -1844,7 +1960,7 @@ def _layer_forward_loop(m, _inputs=layer_inputs):
18441960 torch .cuda .empty_cache ()
18451961 finally :
18461962 input_getter ._unpatch_all_layers ()
1847-
1963+
18481964 print_rank_0 ("Sequential calibration completed" )
18491965
18501966
@@ -1969,9 +2085,9 @@ def hessian_forward(self, input, *args, **kwargs):
19692085 blockwise_weight_update (
19702086 module , hessian , block_size , percdamp , n_samples = state ["n_samples" ]
19712087 )
1972- # Free memory
19732088 del hessian_state [module .name ]
1974- torch .cuda .empty_cache ()
2089+ if torch .cuda .is_available ():
2090+ torch .cuda .empty_cache ()
19752091
19762092 torch .cuda .synchronize () if torch .cuda .is_available () else None
19772093 weight_update_time = time .time () - weight_update_start
0 commit comments