Skip to content

Commit 08a7392

Browse files
committed
integration
Signed-off-by: Shiyang Chen <shiychen@nvidia.com>
1 parent d2c32fc commit 08a7392

2 files changed

Lines changed: 170 additions & 106 deletions

File tree

modelopt/torch/quantization/utils/calib_utils.py

Lines changed: 153 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,42 @@ def update_hessian(input, hessian, n_samples):
7474
return hessian, n_samples
7575

7676

77+
def compute_hessian_inverse(hessian, weight, perc_damp):
    """Compute damped upper-Cholesky inverse Hessian.

    Dead-neuron columns (all-zero in ``weight``) are zeroed in the
    Hessian before inversion, matching the FP-Quant reference:
    https://github.com/IST-DASLab/FP-Quant/blob/d2e3092f968262c4de5fb050e1aef568a280dadd/src/quantization/gptq.py#L200

    Args:
        hessian: Hessian matrix ``[in_features, in_features]``.
        weight: Weight matrix ``[out_features, in_features]`` for dead-neuron detection.
        perc_damp: Percentage of average Hessian diagonal for damping.

    Returns:
        Upper-triangular Cholesky factor of the damped inverse Hessian
        ``[in_features, in_features]``. Falls back to the identity matrix
        when the Hessian is not positive definite.
    """
    damped = hessian.clone()

    # Neutralize dead neurons: zero their rows and columns, then pin their
    # diagonal entries to 1 so the matrix stays invertible.
    dead_mask = weight.eq(0).all(dim=0)
    damped[dead_mask, :] = 0
    damped[:, dead_mask] = 0
    damped[dead_mask, dead_mask] = 1

    # Add a fraction of the mean diagonal to every diagonal entry for
    # numerical stability before the Cholesky factorization.
    dim = damped.shape[0]
    diag_idx = torch.arange(dim, device=damped.device)
    damped[diag_idx, diag_idx] += perc_damp * torch.mean(torch.diag(damped))

    try:
        h_inverse = torch.cholesky_inverse(torch.linalg.cholesky(damped))
        return torch.linalg.cholesky(h_inverse, upper=True)
    except (RuntimeError, torch.linalg.LinAlgError):
        print_rank_0("Warning: Hessian is not positive definite, using identity matrix")
        return torch.eye(dim, device=damped.device, dtype=damped.dtype)
111+
112+
77113
class GPTQHelper:
78114
"""Encapsulates per-module GPTQ state and operations.
79115
@@ -154,38 +190,14 @@ def update_weights(self, block_size, perc_damp):
154190
# ------------------------------------------------------------------
155191

156192
def _prepare_hessian_inverse(self, hessian, perc_damp):
157-
"""Compute damped inverse Hessian and store as ``self.h_inv``.
158-
159-
Dead-neuron columns (all-zero in ``self.weight``) are zeroed in the
160-
Hessian before inversion, matching the FP-Quant reference:
161-
https://github.com/IST-DASLab/FP-Quant/blob/d2e3092f968262c4de5fb050e1aef568a280dadd/src/quantization/gptq.py#L200
162-
"""
193+
"""Compute damped inverse Hessian and store as ``self.h_inv``."""
163194
assert self.weight is not None, "_prepare_hessian_inverse called before update_weights()"
164-
h = hessian.clone()
165-
zero_cols = torch.nonzero(self.weight.eq(0).all(dim=0)).unsqueeze(-1)
166-
167-
h[zero_cols, :] = 0
168-
h[:, zero_cols] = 0
169-
h[zero_cols, zero_cols] = 1
170-
171-
damp = perc_damp * torch.mean(torch.diag(h))
172-
diag_indices = torch.arange(h.shape[0], device=h.device)
173-
h[diag_indices, diag_indices] += damp
174-
175-
try:
176-
h = torch.cholesky_inverse(torch.linalg.cholesky(h))
177-
self.h_inv = torch.linalg.cholesky(h, upper=True)
178-
except (RuntimeError, torch.linalg.LinAlgError):
179-
print_rank_0("Warning: Hessian is not positive definite, using identity matrix")
180-
self.h_inv = torch.eye(h.shape[0], device=h.device, dtype=h.dtype)
195+
self.h_inv = compute_hessian_inverse(hessian, self.weight, perc_damp)
181196

182197
def _blockwise_update(self, block_size):
183198
"""Column-wise GPTQ update using full-matrix QDQ.
184199
185-
For each column, quantizes the full weight matrix via the quantizer and
186-
extracts the quantized column. This is the standard GPTQ approach.
187-
188-
Reads/writes ``self.weight`` and ``self.h_inv`` in-place.
200+
Delegates to :func:`gptq_blockwise_update` with the module's weight quantizer.
189201
"""
190202
assert self.weight is not None and self.h_inv is not None, (
191203
"_blockwise_update called before _prepare_hessian_inverse()"
@@ -199,28 +211,7 @@ def _blockwise_update(self, block_size):
199211
f"GPTQ block_size ({block_size}) must be divisible by the quantizer"
200212
f" group_size ({group_size})"
201213
)
202-
num_cols = self.weight.shape[1]
203-
204-
for block_start in range(0, num_cols, block_size):
205-
block_end = min(block_start + block_size, num_cols)
206-
n_cols_blk = block_end - block_start
207-
h_inv_cho_blk = self.h_inv[block_start:block_end, block_start:block_end]
208-
209-
wblk = self.weight.clone()
210-
errs = torch.zeros_like(wblk[:, block_start:block_end])
211-
212-
for i in range(n_cols_blk):
213-
w_ci = wblk[:, block_start + i]
214-
d = h_inv_cho_blk[i, i]
215-
qdq = quantizer(wblk)
216-
self.weight[:, block_start + i] = qdq[:, block_start + i]
217-
err = (w_ci - qdq[:, block_start + i]) / d
218-
wblk[:, block_start + i : block_end].addr_(err, h_inv_cho_blk[i, i:], alpha=-1)
219-
errs[:, i] = err
220-
221-
self.weight[:, block_end:].addmm_(
222-
errs, self.h_inv[block_start:block_end, block_end:], alpha=-1
223-
)
214+
gptq_blockwise_update(self.weight, self.h_inv, block_size, quantizer)
224215

225216
def _print_mse_error(self, hessian):
226217
"""Log Hessian-weighted relative MSE between ``self.weight`` and original weights."""
@@ -231,6 +222,115 @@ def _print_mse_error(self, hessian):
231222
print_rank_0(f"[{self.name}] Relative MSE error: {mse.item():.2e}{suffix}")
232223

233224

225+
def gptq_blockwise_update(weight, h_inv, block_size, quantize_fn):
    """Column-wise GPTQ update using full-matrix fake quantization.

    For each column, quantizes the full weight matrix via ``quantize_fn`` and
    extracts the quantized column. Error is propagated to remaining columns
    within the block and then to all subsequent columns via the inverse Hessian.

    Args:
        weight: Weight tensor ``[out_features, in_features]``, modified **in-place**
            with fake-quantized values.
        h_inv: Upper-triangular Cholesky factor of the damped inverse Hessian
            ``[in_features, in_features]``.
        block_size: Number of columns to process per GPTQ block.
        quantize_fn: Callable ``(weight) -> qdq_weight`` that fake-quantizes
            the full weight matrix.
    """
    total_cols = weight.shape[1]

    for start in range(0, total_cols, block_size):
        end = min(start + block_size, total_cols)
        h_blk = h_inv[start:end, start:end]

        # Work on a scratch copy so intra-block error feedback does not
        # disturb the columns of ``weight`` already committed.
        scratch = weight.clone()
        block_errs = torch.zeros_like(weight[:, start:end])

        for offset in range(end - start):
            col = start + offset
            pre_quant = scratch[:, col]
            qdq_full = quantize_fn(scratch)
            weight[:, col] = qdq_full[:, col]
            # Quantization error normalized by the Cholesky diagonal.
            col_err = (pre_quant - qdq_full[:, col]) / h_blk[offset, offset]
            # Feed the error into the not-yet-quantized columns of this block.
            scratch[:, col:end] -= torch.outer(col_err, h_blk[offset, offset:])
            block_errs[:, offset] = col_err

        # Propagate the accumulated block error to all trailing columns.
        weight[:, end:].addmm_(block_errs, h_inv[start:end, end:], alpha=-1)
261+
262+
263+
def gptq_blockwise_update_fused_scalar(weight, scales_2d, h_inv, block_size, quant_block_size):
    """Fused GPTQ blockwise update for NVFP4 scalar quantization.

    Uses a fused Triton kernel that combines quantization and per-column
    error propagation into one launch per GPTQ block, avoiding the
    Python-level per-column loop in :func:`gptq_blockwise_update`.

    Args:
        weight: Weight tensor ``[out_features, in_features]``, modified **in-place**
            with fake-quantized values.
        scales_2d: Pre-computed per-block scales ``[out_features, n_scale_blocks]``.
        h_inv: Upper-triangular Cholesky factor of the damped inverse Hessian
            ``[in_features, in_features]``.
        block_size: Number of columns to process per GPTQ block.
        quant_block_size: Number of elements sharing one quantization scale factor.
    """
    from modelopt.torch.quantization.triton.gptq_fused_kernel import gptq_fused_block_scalar

    total_cols = weight.shape[1]
    for start in range(0, total_cols, block_size):
        end = min(start + block_size, total_cols)
        # One kernel launch quantizes the block and returns the per-column
        # error needed for trailing-column propagation.
        quantized_blk, blk_err = gptq_fused_block_scalar(
            weight[:, start:end].clone().contiguous(),
            scales_2d,
            h_inv[start:end, start:end].contiguous(),
            quant_block_size,
            start,
        )
        weight[:, start:end] = quantized_blk
        if end < total_cols:
            weight[:, end:].addmm_(blk_err, h_inv[start:end, end:], alpha=-1)
294+
295+
296+
class FusedScalarGPTQHelper(GPTQHelper):
    """GPTQHelper using the fused Triton kernel for NVFP4 scalar quantization.

    Overrides :meth:`_blockwise_update` to extract pre-computed scales from the
    ``NVFP4StaticQuantizer`` and delegate to :func:`gptq_blockwise_update_fused_scalar`.
    """

    def _blockwise_update(self, block_size):
        """Fused GPTQ using Triton kernel for NVFP4 scalar quantization.

        Raises:
            ValueError: If the quantizer exposes no per-block group size, or
                if ``block_size`` / ``in_features`` is not divisible by it.
        """
        assert self.weight is not None and self.h_inv is not None, (
            "_blockwise_update called before _prepare_hessian_inverse()"
        )
        from modelopt.torch.quantization.triton.fp4_kernel import compute_fp4_scales

        quantizer = self.module.weight_quantizer
        block_sizes = getattr(quantizer, "block_sizes", None)
        quant_block_size = None
        if block_sizes is not None:
            quant_block_size = block_sizes.get(-1) or block_sizes.get(1)

        # The fused scalar kernel needs per-block scales. Without a group size
        # the integer division below would raise an opaque TypeError, so fail
        # loudly here instead.
        if quant_block_size is None:
            raise ValueError(
                "FusedScalarGPTQHelper requires a quantizer with per-block"
                " scaling (block_sizes[-1] or block_sizes[1] must be set)"
            )
        if block_size % quant_block_size != 0:
            raise ValueError(
                f"GPTQ block_size ({block_size}) must be divisible by the quantizer"
                f" group_size ({quant_block_size})"
            )

        out_features, num_cols = self.weight.shape
        # Guard the reshape below with an explicit message rather than letting
        # a shape-mismatch RuntimeError surface from ``reshape``.
        if num_cols % quant_block_size != 0:
            raise ValueError(
                f"in_features ({num_cols}) must be divisible by the quantizer"
                f" group_size ({quant_block_size})"
            )
        n_blocks = num_cols // quant_block_size

        # Pre-compute scales from the calibrated amax (frozen during GPTQ).
        amax = quantizer.amax.reshape(out_features, n_blocks)
        scales_2d = compute_fp4_scales(amax, quantizer.global_amax, quantize_block_scales=True)

        gptq_blockwise_update_fused_scalar(
            self.weight, scales_2d, self.h_inv, block_size, quant_block_size
        )
332+
333+
234334
_GPTQ_HELPER_REGISTRY: dict[str, type[GPTQHelper]] = {}
235335

236336

@@ -242,3 +342,7 @@ def register_gptq_helper(backend: str, factory: type[GPTQHelper]) -> None:
242342
construct ``factory`` instead of the default ``GPTQHelper``.
243343
"""
244344
_GPTQ_HELPER_REGISTRY[backend] = factory
345+
346+
347+
# Built-in registrations: make the fused NVFP4 helper selectable by backend name.
register_gptq_helper("fused_gptq_nvfp4", FusedScalarGPTQHelper)

tests/gpu/torch/quantization/test_gptq.py

Lines changed: 17 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,12 @@
2525
from modelopt.torch.export.unified_export_hf import _export_quantized_weight
2626
from modelopt.torch.quantization.model_calib import gptq
2727
from modelopt.torch.quantization.qtensor.nvfp4_tensor import NVFP4QTensor
28-
from modelopt.torch.quantization.utils.calib_utils import update_hessian
28+
from modelopt.torch.quantization.utils.calib_utils import (
29+
compute_hessian_inverse,
30+
gptq_blockwise_update,
31+
gptq_blockwise_update_fused_scalar,
32+
update_hessian,
33+
)
2934
from modelopt.torch.utils.dataset_utils import create_forward_loop, get_dataset_dataloader
3035

3136
RAND_SEED = 42
@@ -240,21 +245,6 @@ def test_gptq_e2e_flow(quant_cfg):
240245
# ---------------------------------------------------------------------------
241246

242247

243-
# TODO(shiychen): This should be extracted out from production code path
244-
def _compute_h_inv(hessian, weight, percdamp=0.01):
245-
"""Compute damped upper-Cholesky inverse Hessian."""
246-
h = hessian.clone()
247-
zero_cols = torch.nonzero(weight.eq(0).all(dim=0)).unsqueeze(-1)
248-
h[zero_cols, :] = 0
249-
h[:, zero_cols] = 0
250-
h[zero_cols, zero_cols] = 1
251-
damp = percdamp * torch.mean(torch.diag(h))
252-
diag_idx = torch.arange(h.shape[0], device=h.device)
253-
h[diag_idx, diag_idx] += damp
254-
h = torch.cholesky_inverse(torch.linalg.cholesky(h))
255-
return torch.linalg.cholesky(h, upper=True)
256-
257-
258248
def _make_nvfp4_test_data(quant_block_size, out_features, dim):
259249
"""Create weight, h_inv, and scales_2d for NVFP4 GPTQ tests."""
260250
from modelopt.torch.quantization.triton.fp4_kernel import compute_fp4_scales
@@ -268,14 +258,13 @@ def _make_nvfp4_test_data(quant_block_size, out_features, dim):
268258
hessian = torch.zeros(dim, dim, dtype=torch.float32)
269259
hessian, _ = update_hessian(inp, hessian, 0)
270260
hessian = hessian.to("cuda")
271-
h_inv = _compute_h_inv(hessian, weight)
261+
h_inv = compute_hessian_inverse(hessian, weight, perc_damp=0.01)
272262

273263
return weight, scales_2d, h_inv
274264

275265

276-
# TODO(shiychen): This should be extracted out from production code path
277266
def _run_unfused_gptq_nvfp4(weight, scales_2d, h_inv, gptq_block_size, quant_block_size):
278-
"""Unfused NVFP4 GPTQ using the production Triton FP4 kernel per column.
267+
"""Unfused NVFP4 GPTQ using the production blockwise update with Triton FP4 kernel.
279268
280269
Both fused and unfused use the same frozen pre-computed scales so the
281270
test verifies the fused kernel's correctness (not scale computation).
@@ -285,52 +274,23 @@ def _run_unfused_gptq_nvfp4(weight, scales_2d, h_inv, gptq_block_size, quant_blo
285274
out_features, num_cols = weight.shape
286275
n_blocks = num_cols // quant_block_size
287276
w = weight.float().clone()
288-
q = torch.zeros_like(w)
289277
# Recover amax from scales (scales = amax / 6.0, already FP8-quantized)
290278
amax_flat = (scales_2d * 6.0).reshape(out_features * n_blocks)
291279

292-
for i in range(0, num_cols, gptq_block_size):
293-
j_end = min(i + gptq_block_size, num_cols)
294-
e = torch.zeros(out_features, j_end - i, dtype=w.dtype, device=w.device)
295-
296-
for j in range(i, j_end):
297-
# Quantize full weight using production Triton FP4 kernel
298-
w_blocked = w.reshape(out_features * n_blocks, quant_block_size)
299-
qdq = static_blockwise_fp4_fake_quant(
300-
w_blocked,
301-
amax_flat,
302-
quantize_block_scales=False,
303-
).reshape(out_features, num_cols)
304-
q[:, j] = qdq[:, j]
305-
306-
err = (w[:, j] - q[:, j]) / h_inv[j, j]
307-
e[:, j - i] = err
308-
w[:, j:j_end] -= err.unsqueeze(1) * h_inv[j, j:j_end].unsqueeze(0)
280+
def quantize_fn(w_input):
281+
w_blocked = w_input.reshape(out_features * n_blocks, quant_block_size)
282+
return static_blockwise_fp4_fake_quant(
283+
w_blocked, amax_flat, quantize_block_scales=False
284+
).reshape(out_features, num_cols)
309285

310-
if j_end < num_cols:
311-
w[:, j_end:] -= e @ h_inv[i:j_end, j_end:]
312-
313-
return q
286+
gptq_blockwise_update(w, h_inv, gptq_block_size, quantize_fn)
287+
return w
314288

315289

316290
def _run_fused_gptq_nvfp4(weight, scales_2d, h_inv, gptq_block_size, quant_block_size):
    """Fused Triton GPTQ for NVFP4 using the production fused update."""
    # Operate on a float32 copy; the production update mutates it in place.
    work = weight.float().clone()
    gptq_blockwise_update_fused_scalar(
        work, scales_2d, h_inv, gptq_block_size, quant_block_size
    )
    return work
335295

336296

0 commit comments

Comments
 (0)