Skip to content

Commit 3f2d7c0

Browse files
committed
gptq faster
Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com>
1 parent 806e8ac commit 3f2d7c0

3 files changed

Lines changed: 365 additions & 25 deletions

File tree

modelopt/torch/quantization/model_calib.py

Lines changed: 84 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1545,7 +1545,7 @@ def _print_relative_mse_error(
15451545
delta = q - w
15461546
mse = (delta).mm(h).mul(delta).mean() / (w.mm(h).mul(w).mean() + 1e-6)
15471547
suffix = f", n_hessian_samples: {n_samples}" if n_samples is not None else ""
1548-
print(f"[{module_name}] Relative MSE error: {mse.item():.2e}{suffix}")
1548+
print_rank_0(f"[{module_name}] Relative MSE error: {mse.item():.2e}{suffix}")
15491549

15501550

15511551
def update_hessian(input, hessian, n_samples):
@@ -1604,7 +1604,7 @@ def prepare_hessian_inverse(h, weight, percdamp):
16041604
h = torch.cholesky_inverse(torch.linalg.cholesky(h))
16051605
h_inv = torch.linalg.cholesky(h, upper=True)
16061606
except (RuntimeError, torch.linalg.LinAlgError):
1607-
print("Warning: Hessian is not positive definite, using identity matrix")
1607+
print_rank_0("Warning: Hessian is not positive definite, using identity matrix")
16081608
h_inv = torch.eye(h.shape[0], device=h.device, dtype=h.dtype)
16091609
return h_inv
16101610

@@ -1706,37 +1706,104 @@ def _column_qdq_tensor(col, col_idx, _s=scalar_scale, _mx=max_bound, _mn=min_bou
17061706
return _column_qdq_tensor, True
17071707

17081708

1709+
def _can_use_fused_gptq(quantizer) -> bool:
    """Report whether *quantizer* is eligible for the fused Triton GPTQ kernel.

    Eligibility requires an :class:`NVFP4StaticQuantizer` instance whose
    ``_amax`` attribute exists and is not ``None``; when both hold, the result
    is the ``IS_AVAILABLE`` flag exported by the Triton sub-package.
    """
    is_nvfp4_static = isinstance(quantizer, NVFP4StaticQuantizer)
    if is_nvfp4_static and getattr(quantizer, "_amax", None) is not None:
        # Only consult the Triton package once the quantizer itself qualifies.
        from modelopt.torch.quantization.triton import IS_AVAILABLE as triton_available

        return triton_available
    return False
1718+
1719+
17091720
def blockwise_weight_update(module, h, block_size, percdamp, n_samples=None):
    """Apply GPTQ-style blockwise quantization to ``module.weight`` and store the result.

    The column sweep is delegated to one of two helpers:

    * ``_blockwise_weight_update_fused`` when ``_can_use_fused_gptq`` accepts
      the module's weight quantizer (NVFP4 static quantizer with Triton
      available).
    * ``_blockwise_weight_update_unfused`` otherwise. That helper receives the
      per-column fake-quant callable and the "supported" flag returned by
      ``_build_column_qdq`` and picks between the column-QDQ path and the
      full-matrix fallback.

    Args:
        module: Neural network module with ``weight`` and ``weight_quantizer``.
        h: Hessian matrix of shape ``(d, d)``.
        block_size: Number of columns processed per block.
        percdamp: Damping as a fraction of the mean Hessian diagonal.
        n_samples: Number of Hessian samples (used only for logging).
    """
    # Work on a float32 copy so the original weights stay available for the
    # error report below.
    work = module.weight.data.float().clone()
    n_rows, n_cols = work.shape

    h_inv = prepare_hessian_inverse(h, work, percdamp)

    wq = module.weight_quantizer
    if not _can_use_fused_gptq(wq):
        qdq_fn, qdq_ok = _build_column_qdq(wq, work.shape)
        _blockwise_weight_update_unfused(work, h_inv, wq, n_cols, block_size, qdq_fn, qdq_ok)
    else:
        _blockwise_weight_update_fused(work, h_inv, wq, n_rows, n_cols, block_size)

    # Report the error against the still-unmodified module weights, then
    # write the quantized result back in the module's original shape and dtype.
    _print_relative_mse_error(work, module.weight.float(), h, module.name, n_samples)
    module.weight.data = work.reshape(module.weight.shape).to(module.weight.data.dtype)
1756+
1757+
1758+
def _blockwise_weight_update_fused(weight, h_inv, quantizer, num_rows, num_cols, block_size):
    """Run the GPTQ column sweep over *weight* in place with the fused Triton kernel.

    Columns are processed in blocks of ``block_size``. Each block is handed to
    ``gptq_fused_block`` (one kernel launch), the quantized columns are written
    back, and the block's errors are propagated to all columns to its right.

    Args:
        weight: Working weight matrix indexed ``[row, column]``; updated in place.
        h_inv: Matrix produced by ``prepare_hessian_inverse``; its diagonal
            blocks are passed to the kernel and its off-diagonal blocks drive
            the cross-block error propagation.
        quantizer: Quantizer exposing ``block_sizes``, ``_amax`` and ``global_amax``.
        num_rows: Number of rows of ``weight``.
        num_cols: Number of columns of ``weight``.
        block_size: Number of columns processed per kernel launch.

    Raises:
        ValueError: If ``quantizer.block_sizes`` defines no positive group size
            under key ``-1`` or ``1``.
    """
    block_sizes = quantizer.block_sizes
    group_size = block_sizes.get(-1, None) or block_sizes.get(1, None)
    if not group_size or group_size < 0:
        raise ValueError(
            "Fused GPTQ requires a positive quantization group size in "
            f"quantizer.block_sizes (key -1 or 1), got {block_sizes!r}"
        )

    # Imported lazily so Triton is only required when this path is taken.
    from modelopt.torch.quantization.triton.gptq_fused_kernel import gptq_fused_block

    # Integer ceil division: exact for any size, unlike float-based math.ceil.
    num_groups = -(-num_cols // group_size)
    amax_grouped = quantizer._amax.float().reshape(num_rows, num_groups).contiguous()
    global_amax = quantizer.global_amax.float()

    for block_start in range(0, num_cols, block_size):
        block_end = min(block_start + block_size, num_cols)
        n_cols_blk = block_end - block_start

        # Private contiguous copy: the kernel updates its working block in place.
        w_block = weight[:, block_start:block_end].clone(memory_format=torch.contiguous_format)
        h_inv_cho_blk = h_inv[block_start:block_end, block_start:block_end].contiguous()

        qw_block, err_block = gptq_fused_block(
            w_block,
            amax_grouped,
            global_amax,
            h_inv_cho_blk,
            group_size,
            block_start,
            n_cols_blk,
        )

        weight[:, block_start:block_end] = qw_block
        if block_end < num_cols:
            # Propagate this block's errors to every column on its right.
            weight[:, block_end:].addmm_(
                err_block[:, :n_cols_blk],
                h_inv[block_start:block_end, block_end:],
                alpha=-1,
            )
1791+
1792+
1793+
def _blockwise_weight_update_unfused(
1794+
weight, h_inv, quantizer, num_cols, block_size, col_qdq_fn, col_qdq_supported
1795+
):
1796+
"""Column-QDQ or full-matrix fallback for non-NVFP4 quantizers."""
1797+
for block_start in range(0, num_cols, block_size):
1798+
block_end = min(block_start + block_size, num_cols)
1799+
n_cols_blk = block_end - block_start
17321800
h_inv_cho_blk = h_inv[block_start:block_end, block_start:block_end]
17331801

17341802
if col_qdq_supported:
1735-
# Fast path: clone only the block columns, quantize only per-column
17361803
wblk = weight[:, block_start:block_end].clone()
17371804
errs = torch.zeros_like(wblk)
17381805

1739-
for i in range(n_cols):
1806+
for i in range(n_cols_blk):
17401807
w_ci = wblk[:, i]
17411808
d = h_inv_cho_blk[i, i]
17421809
qdq_col = col_qdq_fn(w_ci, block_start + i)
@@ -1745,27 +1812,20 @@ def blockwise_weight_update(module, h, block_size, percdamp, n_samples=None):
17451812
wblk[:, i:].addr_(err, h_inv_cho_blk[i, i:], alpha=-1)
17461813
errs[:, i] = err
17471814
else:
1748-
# Fallback: original full-matrix quantization path
17491815
wblk = weight.clone()
17501816
errs = torch.zeros_like(wblk[:, block_start:block_end])
17511817

1752-
for i in range(n_cols):
1818+
for i in range(n_cols_blk):
17531819
w_ci = wblk[:, block_start + i]
17541820
d = h_inv_cho_blk[i, i]
1755-
qdq = module.weight_quantizer(wblk)
1821+
qdq = quantizer(wblk)
17561822
weight[:, block_start + i] = qdq[:, block_start + i]
17571823
err = (w_ci - qdq[:, block_start + i]) / d
17581824
wblk[:, block_start + i : block_end].addr_(err, h_inv_cho_blk[i, i:], alpha=-1)
17591825
errs[:, i] = err
17601826

1761-
# Propagate errors to remaining weights
17621827
weight[:, block_end:].addmm_(errs, h_inv[block_start:block_end, block_end:], alpha=-1)
17631828

1764-
# Print relative mse error
1765-
_print_relative_mse_error(weight, module.weight.float(), h, module.name, n_samples)
1766-
# Update module weights
1767-
module.weight.data = weight.reshape(module.weight.shape).to(module.weight.data.dtype)
1768-
17691829

17701830
def gptq_lite(
17711831
model: nn.Module,
Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Fused Triton kernel for the GPTQ blockwise weight-update inner loop.
17+
18+
The standard GPTQ inner loop launches ~10-15 CUDA kernels per column
19+
(amax lookup, FP4 quantization, error computation, rank-1 update).
20+
For ``block_size=128`` that is ~1 500 kernel launches per block, each with
21+
~5-10 us of launch overhead dominating actual compute.
22+
23+
This module fuses the entire inner loop into a **single** Triton kernel per
24+
block. Rows are independent and map to Triton programs; columns are processed
25+
sequentially inside each program so the rank-1 error update is carried forward
26+
without synchronisation.
27+
28+
Supported quantisation format: **NVFP4 static block quantisation** (two-level
29+
scaling with per-group amax and a global amax).
30+
"""
31+
32+
import torch
33+
import triton
34+
import triton.language as tl
35+
36+
__all__ = ["gptq_fused_block"]
37+
38+
# -- NVFP4 constants used by the kernel ------------------------------------
39+
# Maximum representable FP4-E2M1 value (2**2 * 1.5 = 6.0 when decoded via
40+
# the standard E2M1 table: {0, 0.5, 1, 1.5, 2, 3, 4, 6}).
41+
_FP4_MAX = 6.0
42+
# FP8-E4M3 has max representable value 448.
43+
_FP8_E4M3_MAX = 448.0
44+
45+
46+
@triton.jit
def _gptq_fused_block_kernel(
    w_ptr,  # [num_rows, BLOCK_SIZE] working weight block (in-place)
    qw_ptr,  # [num_rows, BLOCK_SIZE] output: quantized weights
    err_ptr,  # [num_rows, BLOCK_SIZE] output: quantization errors
    amax_ptr,  # [num_rows, num_groups] per-group amax, row-major
    global_amax_ptr,  # scalar float32 on device
    hinv_ptr,  # [BLOCK_SIZE, BLOCK_SIZE] upper Cholesky of H^{-1}
    num_rows,
    num_groups,
    group_size: tl.constexpr,
    block_start,  # column offset of this block in the full weight matrix
    n_cols,  # actual columns in this block (may be < BLOCK_SIZE)
    BLOCK_SIZE: tl.constexpr,
):
    """One program per row; sequentially quantizes columns, propagating errors.

    For each column ``i`` of the row the kernel:

    1. derives the NVFP4 two-level scale from the column's group amax and the
       global amax (the group scale is round-tripped through FP8-E4M3),
    2. fake-quantizes the current working weight to the FP4-E2M1 grid and
       stores it to ``qw_ptr``,
    3. stores the error ``(w - q) / hinv[i, i]`` to ``err_ptr``,
    4. subtracts ``err * hinv[i, j]`` from the working weights at columns
       ``i < j < n_cols`` of the same row, in place in ``w_ptr``.
    """
    row = tl.program_id(0)
    if row >= num_rows:
        return

    # Base pointers for this row
    w_base = w_ptr + row * BLOCK_SIZE
    qw_base = qw_ptr + row * BLOCK_SIZE
    err_base = err_ptr + row * BLOCK_SIZE
    amax_row_base = amax_ptr + row * num_groups

    # Global FP8 scale factors are constant across columns, so both directions
    # are computed once here (6.0 = FP4-E2M1 max, 448.0 = FP8-E4M3 max).
    global_amax = tl.load(global_amax_ptr).to(tl.float32)
    global_scale = global_amax / 6.0
    fp8_scale = tl.where(global_scale > 0.0, 448.0 / global_scale, 1.0)
    fp8_inv_scale = tl.where(global_scale > 0.0, 1.0 / (448.0 / global_scale), 0.0)

    j_range = tl.arange(0, BLOCK_SIZE)

    for i in range(BLOCK_SIZE):
        wi = tl.load(w_base + i)

        # -- Compute NVFP4 two-level scale for this column's group -----------
        col_idx = block_start + i
        group_idx = col_idx // group_size
        raw_amax = tl.load(amax_row_base + group_idx).to(tl.float32)
        raw_scale = raw_amax / 6.0

        # FP8-quantize the block scale: scale * fp8_scale -> cast E4M3 -> back
        si = (raw_scale * fp8_scale).to(tl.float8e4nv).to(tl.float32) * fp8_inv_scale

        # Guard: replace zero / nan / inf scale with 1.0
        # NOTE: ``si != si`` is the standard NaN check in Triton (no math.isnan).
        si_safe = tl.where(
            (si == 0.0) | (si != si) | (tl.abs(si) == float("inf")),  # noqa: PLR0124
            1.0,
            si,
        )

        # -- FP4-E2M1 fake quantization (nearest-round to 8 levels) ----------
        # Ties at the midpoints alternate between <= and < so that they land
        # on 0, 1, 2 and 4 rather than 0.5, 1.5, 3 and 6.
        abs_scaled = tl.abs(wi) / si_safe
        q_val = tl.where(
            abs_scaled <= 0.25,
            0.0,
            tl.where(
                abs_scaled < 0.75,
                0.5,
                tl.where(
                    abs_scaled <= 1.25,
                    1.0,
                    tl.where(
                        abs_scaled < 1.75,
                        1.5,
                        tl.where(
                            abs_scaled <= 2.5,
                            2.0,
                            tl.where(abs_scaled < 3.5, 3.0, tl.where(abs_scaled <= 5.0, 4.0, 6.0)),
                        ),
                    ),
                ),
            ),
        )

        qi = q_val * si_safe * tl.where(wi >= 0.0, 1.0, -1.0)
        tl.store(qw_base + i, qi)

        # -- GPTQ error and rank-1 update ------------------------------------
        di = tl.load(hinv_ptr + i * BLOCK_SIZE + i)
        err_i = (wi - qi) / di
        tl.store(err_base + i, err_i)

        # Only columns to the right of i that really exist are updated.
        j_mask = (j_range > i) & (j_range < n_cols)
        hinv_row = tl.load(hinv_ptr + i * BLOCK_SIZE + j_range, mask=j_mask, other=0.0)
        w_rem = tl.load(w_base + j_range, mask=j_mask, other=0.0)
        w_rem = w_rem - err_i * hinv_row
        tl.store(w_base + j_range, w_rem, mask=j_mask)

        # The scalar load at the top of the next iteration reads an element
        # that a different thread of this program may have just stored above;
        # synchronise so that store is visible before it is read back.
        tl.debug_barrier()
138+
139+
def gptq_fused_block(
    w_block: torch.Tensor,
    amax_grouped: torch.Tensor,
    global_amax: torch.Tensor,
    h_inv_cho_blk: torch.Tensor,
    group_size: int,
    block_start: int,
    n_cols: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Run the GPTQ column loop for one block in a single Triton kernel launch.

    The kernel needs a power-of-two block width (``tl.arange`` extent). When
    ``w_block`` is narrower than the next power of two, the weights are
    zero-padded, the Cholesky block is padded with an identity diagonal and the
    amax table is padded with ones, so the extra columns quantize to zero with
    zero error; the padding is sliced off before returning.

    Args:
        w_block: Working weight block of shape ``[num_rows, block_size]``. It is
            copied; the caller's tensor is never modified.
        amax_grouped: Per-group amax of shape ``[num_rows, num_groups]``.
        global_amax: Scalar tensor with the global amax.
        h_inv_cho_blk: Upper Cholesky factor of H^{-1}, shape ``[block_size, block_size]``.
        group_size: NVFP4 quantization group size (typically 16).
        block_start: Column offset of this block in the full weight matrix.
        n_cols: Actual number of columns in this block (``<= block_size``).

    Returns:
        Tuple of ``(qw_block, err_block)`` each of shape ``[num_rows, block_size]``.

    Raises:
        ValueError: If the tensor shapes are inconsistent with each other, if
            ``group_size`` is not positive, or if ``amax_grouped`` has too few
            groups for the columns of this block.
    """
    # The kernel does unchecked pointer arithmetic, so validate shapes up front.
    if w_block.dim() != 2 or amax_grouped.dim() != 2:
        raise ValueError("w_block and amax_grouped must both be 2-D tensors")
    num_rows, block_size = w_block.shape
    num_groups = amax_grouped.shape[1]
    if group_size <= 0:
        raise ValueError(f"group_size must be positive, got {group_size}")
    if amax_grouped.shape[0] != num_rows:
        raise ValueError(
            f"amax_grouped has {amax_grouped.shape[0]} rows but w_block has {num_rows}"
        )
    if tuple(h_inv_cho_blk.shape) != (block_size, block_size):
        raise ValueError(
            f"h_inv_cho_blk must have shape ({block_size}, {block_size}), "
            f"got {tuple(h_inv_cho_blk.shape)}"
        )
    cols_in_use = min(n_cols, block_size)
    if cols_in_use > 0 and (block_start + cols_in_use - 1) // group_size >= num_groups:
        raise ValueError(
            f"amax_grouped has {num_groups} groups, too few for columns "
            f"[{block_start}, {block_start + cols_in_use}) with group_size={group_size}"
        )

    # Nothing to launch for an empty block.
    if num_rows == 0 or block_size == 0:
        return torch.empty_like(w_block), torch.empty_like(w_block)

    # Next power of two >= block_size.
    padded_size = 1 << (block_size - 1).bit_length()

    if padded_size == block_size:
        work = w_block.clone(memory_format=torch.contiguous_format)
        h_inv_work = h_inv_cho_blk.contiguous()
    else:
        work = w_block.new_zeros((num_rows, padded_size))
        work[:, :block_size] = w_block
        # Identity on the padded diagonal keeps the kernel's division finite.
        h_inv_work = torch.eye(
            padded_size, dtype=h_inv_cho_blk.dtype, device=h_inv_cho_blk.device
        )
        h_inv_work[:block_size, :block_size] = h_inv_cho_blk

    # The kernel looks up an amax group for every column it visits, including
    # padded ones; extend the table so those look-ups stay in bounds.
    amax_work = amax_grouped.contiguous()
    groups_needed = (block_start + padded_size - 1) // group_size + 1
    if groups_needed > num_groups:
        extended = amax_work.new_ones((num_rows, groups_needed))
        extended[:, :num_groups] = amax_work
        amax_work = extended
        num_groups = groups_needed

    qw_work = torch.empty_like(work)
    err_work = torch.empty_like(work)

    grid = (num_rows,)
    with torch.cuda.device(w_block.device):
        _gptq_fused_block_kernel[grid](
            work,
            qw_work,
            err_work,
            amax_work,
            global_amax,
            h_inv_work,
            num_rows,
            num_groups,
            group_size,
            block_start,
            n_cols,
            BLOCK_SIZE=padded_size,
        )

    if padded_size != block_size:
        return (
            qw_work[:, :block_size].contiguous(),
            err_work[:, :block_size].contiguous(),
        )
    return qw_work, err_work

0 commit comments

Comments
 (0)