Skip to content

Commit 806e8ac

Browse files
committed
update
Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com>
1 parent fec8f89 commit 806e8ac

1 file changed

Lines changed: 129 additions & 13 deletions

File tree

modelopt/torch/quantization/model_calib.py

Lines changed: 129 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1609,6 +1609,103 @@ def prepare_hessian_inverse(h, weight, percdamp):
16091609
return h_inv
16101610

16111611

1612+
def _build_column_qdq(quantizer, weight_shape):
    """Build a fast column-wise quantize-dequantize function for integer quantizers.

    Instead of calling the full TensorQuantizer on the entire weight matrix (which
    quantizes all elements) and extracting one column, this returns a closure that
    quantizes only a single column using the quantizer's pre-computed amax/scales.

    The closure reuses the amax stored on the quantizer at build time, so it is
    only meant to be used while that amax stays fixed (e.g. during GPTQ weight
    updates after calibration). Any configuration this helper cannot reproduce
    with confidence is reported as unsupported so the caller can fall back to the
    full quantizer.

    Args:
        quantizer: The weight TensorQuantizer (already calibrated).
        weight_shape: Shape of the weight tensor (out_features, in_features).

    Returns:
        Tuple of (column_qdq_fn, supported) where:
            - column_qdq_fn(column, col_idx) -> qdq_column (None if unsupported)
            - supported: True if column-wise qdq is available, False to fall back.
    """
    # Unsupported: NVFP4 (two-level FP4 scaling), FP quantization (num_bits is a tuple)
    if isinstance(quantizer, NVFP4StaticQuantizer):
        return None, False
    num_bits = quantizer._num_bits
    if isinstance(num_bits, tuple):
        return None, False

    # Unsupported: pre_quant_scale (SmoothQuant) or rotation transforms mix columns
    if getattr(quantizer, "pre_quant_scale", None) is not None:
        return None, False
    if getattr(quantizer, "rotate_is_enabled", False):
        return None, False

    # Unsigned quantizers need a non-negative clamp range; rather than guess the
    # exact lower bound the full quantizer applies, defer to it.
    if getattr(quantizer, "_unsigned", False):
        return None, False

    # Need calibrated amax
    amax = getattr(quantizer, "_amax", None)
    if amax is None:
        return None, False

    # The closures operate on weight[:, col]; that only makes sense for 2-D weights.
    if len(weight_shape) != 2:
        return None, False
    out_features, in_features = weight_shape

    # Signed integer range: [-max_bound, max_bound] when narrow_range is set,
    # otherwise the extra most-negative code is also representable.
    narrow_range = getattr(quantizer, "_narrow_range", False)
    max_bound = (2 ** (num_bits - 1)) - 1
    min_bound = -max_bound - int(not narrow_range)

    amax = amax.float()

    # Determine quantization geometry from block_sizes
    block_sizes = quantizer.block_sizes
    group_size = None
    if block_sizes is not None:
        # Skip dynamic block quantization
        if block_sizes.get("type", "static") == "dynamic":
            return None, False
        group_size = block_sizes.get(-1, None) or block_sizes.get(len(weight_shape) - 1, None)

    if group_size is not None and group_size > 0:
        # Per-group block quantization along the last dim: one amax per
        # (output row, group of `group_size` consecutive input columns).
        if in_features % group_size != 0:
            return None, False  # Padding case — fall back

        n_groups = in_features // group_size
        if amax.numel() != out_features * n_groups:
            return None, False  # amax layout does not match the expected grouping

        # Pre-compute every (row, group) scale once so each column call is a lookup.
        group_scale = max_bound / amax.reshape(out_features, n_groups).clamp(min=1e-12)

        def _column_qdq_group(col, col_idx):
            col_scale = group_scale[:, col_idx // group_size]
            return torch.clamp(torch.round(col * col_scale), min_bound, max_bound) / col_scale

        return _column_qdq_group, True

    # Per-channel (axis != None) or per-tensor (axis == None)
    axis = quantizer.axis
    if axis is not None:
        # A column holds one value per output row, so the scale must vary along
        # dim 0 only. Anything else (e.g. per-input-column amax) would be
        # silently mis-applied, so fall back.
        axes = (axis,) if isinstance(axis, int) else tuple(axis)
        if {a % len(weight_shape) for a in axes} != {0}:
            return None, False
        flat_amax = amax.reshape(-1)
        if flat_amax.numel() not in (1, out_features):
            return None, False

        row_scale = max_bound / flat_amax.clamp(min=1e-12)

        def _column_qdq_channel(col, col_idx):
            return torch.clamp(torch.round(col * row_scale), min_bound, max_bound) / row_scale

        return _column_qdq_channel, True

    # Per-tensor: single scalar scale. A multi-element amax here means the
    # geometry is not what this helper understands — fall back instead of raising.
    if amax.numel() != 1:
        return None, False
    scalar_scale = max_bound / amax.clamp(min=1e-12).item()

    def _column_qdq_tensor(col, col_idx):
        return torch.clamp(torch.round(col * scalar_scale), min_bound, max_bound) / scalar_scale

    return _column_qdq_tensor, True
1707+
1708+
16121709
def blockwise_weight_update(module, h, block_size, percdamp, n_samples=None):
16131710
"""Update module weights using GPTQ-style blockwise quantization.
16141711
@@ -1625,22 +1722,41 @@ def blockwise_weight_update(module, h, block_size, percdamp, n_samples=None):
16251722
# Preprocess Hessian: handle dead neurons and add damping
16261723
h_inv = prepare_hessian_inverse(h, weight, percdamp)
16271724

1725+
# Try to build fast column-wise qdq (avoids quantizing the full matrix per column)
1726+
col_qdq_fn, col_qdq_supported = _build_column_qdq(module.weight_quantizer, weight.shape)
1727+
16281728
# Process weights in blocks
16291729
for block_start in range(0, num_cols, block_size):
16301730
block_end = min(block_start + block_size, num_cols)
16311731
n_cols = block_end - block_start
1632-
wblk = weight.clone()
1633-
errs = torch.zeros_like(wblk[:, block_start:block_end])
16341732
h_inv_cho_blk = h_inv[block_start:block_end, block_start:block_end]
16351733

1636-
for i in range(n_cols):
1637-
w_ci = wblk[:, block_start + i]
1638-
d = h_inv_cho_blk[i, i]
1639-
qdq = module.weight_quantizer(wblk)
1640-
weight[:, block_start + i] = qdq[:, block_start + i]
1641-
err = (w_ci - qdq[:, block_start + i]) / d
1642-
wblk[:, block_start + i : block_end].addr_(err, h_inv_cho_blk[i, i:], alpha=-1)
1643-
errs[:, i] = err
1734+
if col_qdq_supported:
1735+
# Fast path: clone only the block columns, quantize only per-column
1736+
wblk = weight[:, block_start:block_end].clone()
1737+
errs = torch.zeros_like(wblk)
1738+
1739+
for i in range(n_cols):
1740+
w_ci = wblk[:, i]
1741+
d = h_inv_cho_blk[i, i]
1742+
qdq_col = col_qdq_fn(w_ci, block_start + i)
1743+
weight[:, block_start + i] = qdq_col
1744+
err = (w_ci - qdq_col) / d
1745+
wblk[:, i:].addr_(err, h_inv_cho_blk[i, i:], alpha=-1)
1746+
errs[:, i] = err
1747+
else:
1748+
# Fallback: original full-matrix quantization path
1749+
wblk = weight.clone()
1750+
errs = torch.zeros_like(wblk[:, block_start:block_end])
1751+
1752+
for i in range(n_cols):
1753+
w_ci = wblk[:, block_start + i]
1754+
d = h_inv_cho_blk[i, i]
1755+
qdq = module.weight_quantizer(wblk)
1756+
weight[:, block_start + i] = qdq[:, block_start + i]
1757+
err = (w_ci - qdq[:, block_start + i]) / d
1758+
wblk[:, block_start + i : block_end].addr_(err, h_inv_cho_blk[i, i:], alpha=-1)
1759+
errs[:, i] = err
16441760

16451761
# Propagate errors to remaining weights
16461762
weight[:, block_end:].addmm_(errs, h_inv[block_start:block_end, block_end:], alpha=-1)
@@ -1844,7 +1960,7 @@ def _layer_forward_loop(m, _inputs=layer_inputs):
18441960
torch.cuda.empty_cache()
18451961
finally:
18461962
input_getter._unpatch_all_layers()
1847-
1963+
18481964
print_rank_0("Sequential calibration completed")
18491965

18501966

@@ -1969,9 +2085,9 @@ def hessian_forward(self, input, *args, **kwargs):
19692085
blockwise_weight_update(
19702086
module, hessian, block_size, percdamp, n_samples=state["n_samples"]
19712087
)
1972-
# Free memory
19732088
del hessian_state[module.name]
1974-
torch.cuda.empty_cache()
2089+
if torch.cuda.is_available():
2090+
torch.cuda.empty_cache()
19752091

19762092
torch.cuda.synchronize() if torch.cuda.is_available() else None
19772093
weight_update_time = time.time() - weight_update_start

0 commit comments

Comments
 (0)