Skip to content

Commit b331ed2

Browse files
committed
add config for fused kernel; qdq inside kernel; even more extraction
Signed-off-by: Shiyang Chen <shiychen@nvidia.com>
1 parent 08a7392 commit b331ed2

7 files changed

Lines changed: 164 additions and 120 deletions

File tree

modelopt/torch/quantization/config.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1549,6 +1549,12 @@ class GPTQCalibConfig(QuantizeAlgorithmConfig):
15491549
description="""The block size for GPTQ weight update, which must be a multiple of the
15501550
group_size used in the quantization.""",
15511551
)
1552+
fused: bool = ModeloptField(
1553+
default=False,
1554+
title="Use fused Triton kernel for GPTQ.",
1555+
description="""When True, use a fused Triton kernel that combines quantization and
1556+
per-column error propagation into one launch per GPTQ block.""",
1557+
)
15521558

15531559

15541560
QuantizeQuantCfgType = list[QuantizerCfgEntry]

modelopt/torch/quantization/model_calib.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1698,6 +1698,7 @@ def gptq(
16981698
forward_loop: ForwardLoop,
16991699
perc_damp: float = 0.01,
17001700
block_size: int = 128,
1701+
fused: bool = False,
17011702
):
17021703
"""GPTQ quantization.
17031704
@@ -1723,6 +1724,7 @@ def gptq(
17231724
forward_loop: Callable that replays calibration inputs through *model*.
17241725
perc_damp: Percentage of avg Hessian diagonal for damping (default: 0.01).
17251726
block_size: Block size for GPTQ weight update.
1727+
fused: If True, use fused Triton kernel for NVFP4 static quantization.
17261728
"""
17271729
total_start = time.time()
17281730

@@ -1745,7 +1747,7 @@ def _make_gptq_handle(name, m):
17451747
cls = GPTQHelper
17461748
else:
17471749
cls = _GPTQ_HELPER_REGISTRY.get(backend, GPTQHelper)
1748-
return cls(m, name, offload_to_cpu=True)
1750+
return cls(m, name, offload_to_cpu=True, fused=fused)
17491751

17501752
gptq_handles = {name: _make_gptq_handle(name, m) for name, m in quantized_layers}
17511753
for handle in gptq_handles.values():

modelopt/torch/quantization/triton/fp4_kernel_hopper.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
import triton.language as tl
2525

2626
from .fp4_kernel import _torch_dtype_to_tl
27-
from .nvfp4_quant import fp4_round_magnitude
27+
from .nvfp4_quant import fp4_round_magnitude, fp8_quantize_scale
2828

2929
__all__ = ["fp4_fake_quant_block"]
3030

@@ -80,9 +80,7 @@ def fp4_fake_quant_kernel(
8080

8181
block_max = tl.max(x_abs, axis=2, keep_dims=True)
8282

83-
block_max_scaled = block_max / (6.0 * global_scale_safe)
84-
block_max_scaled = tl.minimum(block_max_scaled, 448.0)
85-
block_max_quant = block_max_scaled.to(tl.float8e4nv).to(tl.float32) * global_scale
83+
block_max_quant = fp8_quantize_scale(block_max, global_scale_safe)
8684
block_max_quant = tl.where(block_max_quant >= 1e-5, block_max_quant, 1.0)
8785

8886
block_max_quant_broadcast = tl.broadcast_to(

modelopt/torch/quantization/triton/gptq_fused_kernel.py

Lines changed: 26 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -15,30 +15,28 @@
1515

1616
"""Fused Triton kernels for GPTQ blockwise weight-update.
1717
18-
A kernel for scalar (NVFP4) quantization.
19-
Each kernel fuses quantization + per-column GPTQ error propagation into
18+
A kernel for scalar (NVFP4) quantization with inline two-level scale computation.
19+
Fuses scale computation + quantization + per-column GPTQ error propagation into
2020
one launch per GPTQ block, avoiding the Python-level per-column loop.
2121
22-
Architecture (both kernels):
22+
Architecture:
2323
- One Triton program per output row.
2424
- ``w_full [BLOCK_SIZE]`` register tensor holds working weights.
25-
- Per-column error propagation: ``w_full -= err * h_inv_row``.
26-
27-
Scalar kernel (``_gptq_scalar_kernel``):
28-
- Calls ``nvfp4_scalar_quant()`` from ``nvfp4_quant.py`` per column.
25+
- Per-column: calls ``nvfp4_scalar_qdq()`` for FP4 QDQ with inline scale
26+
computation, then propagates error via ``w_full -= err * h_inv_row``.
2927
"""
3028

3129
import torch
3230
import triton
3331
import triton.language as tl
3432

35-
from .nvfp4_quant import nvfp4_scalar_quant
33+
from .nvfp4_quant import nvfp4_scalar_qdq
3634

3735
__all__ = ["gptq_fused_block_scalar"]
3836

3937

4038
# ---------------------------------------------------------------------------
41-
# Scalar kernel — NVFP4 quantization + error propagation
39+
# Scalar kernel — NVFP4 QDQ + error propagation
4240
# ---------------------------------------------------------------------------
4341

4442

@@ -47,10 +45,11 @@ def _gptq_scalar_kernel(
4745
w_ptr,
4846
qw_ptr,
4947
err_ptr,
50-
scales_ptr,
48+
amax_ptr,
49+
global_scale,
5150
hinv_ptr,
5251
num_rows,
53-
n_scale_blocks,
52+
n_amax_blocks,
5453
quant_block_size,
5554
block_start,
5655
BLOCK_SIZE: tl.constexpr,
@@ -62,19 +61,20 @@ def _gptq_scalar_kernel(
6261
w_base = w_ptr + row * BLOCK_SIZE
6362
qw_base = qw_ptr + row * BLOCK_SIZE
6463
err_base = err_ptr + row * BLOCK_SIZE
65-
scales_base = scales_ptr + row * n_scale_blocks
64+
amax_base = amax_ptr + row * n_amax_blocks
6665

6766
j_range = tl.arange(0, BLOCK_SIZE)
6867
w_full = tl.load(w_base + j_range)
6968

7069
for col in range(0, BLOCK_SIZE, 1):
71-
scale = tl.load(scales_base + (block_start + col) // quant_block_size)
70+
block_amax = tl.load(amax_base + (block_start + col) // quant_block_size)
7271

7372
w_scalar = tl.sum(tl.where(j_range == col, w_full, 0.0))
7473
q_scalar = tl.sum(
75-
nvfp4_scalar_quant(
74+
nvfp4_scalar_qdq(
7675
tl.full([1], w_scalar, dtype=tl.float32),
77-
scale,
76+
block_amax,
77+
global_scale,
7878
1,
7979
)
8080
)
@@ -91,16 +91,22 @@ def _gptq_scalar_kernel(
9191

9292
def gptq_fused_block_scalar(
9393
w_block: torch.Tensor,
94-
scales_2d: torch.Tensor,
94+
block_amax: torch.Tensor,
95+
global_scale: float,
9596
h_inv_cho_blk: torch.Tensor,
9697
quant_block_size: int,
9798
block_start: int,
9899
) -> tuple[torch.Tensor, torch.Tensor]:
99100
"""Run scalar GPTQ (NVFP4) column loop for one block in a single Triton kernel launch.
100101
102+
Computes FP8-quantized scales from per-block amax inline via
103+
:func:`nvfp4_scalar_qdq`, then performs NVFP4 fake quantization and
104+
GPTQ error propagation per column.
105+
101106
Args:
102107
w_block: Working weights ``[num_rows, block_size]`` (float32).
103-
scales_2d: Pre-computed scales ``[num_rows, n_scale_blocks]`` (float32).
108+
block_amax: Per-block amax values ``[num_rows, n_amax_blocks]`` (float32).
109+
global_scale: Pre-computed ``global_amax / (6.0 * 448.0)`` (scalar).
104110
h_inv_cho_blk: Block of upper-Cholesky inverse Hessian ``[block_size, block_size]``.
105111
quant_block_size: Number of elements sharing one scale factor.
106112
block_start: Column offset of this block in the full weight matrix.
@@ -117,10 +123,11 @@ def gptq_fused_block_scalar(
117123
w_block.contiguous(),
118124
qw_block,
119125
err_block,
120-
scales_2d.contiguous(),
126+
block_amax.contiguous(),
127+
global_scale,
121128
h_inv_cho_blk.contiguous(),
122129
num_rows,
123-
scales_2d.shape[1],
130+
block_amax.shape[1],
124131
quant_block_size,
125132
block_start,
126133
BLOCK_SIZE=block_size,

modelopt/torch/quantization/triton/nvfp4_quant.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,3 +93,52 @@ def nvfp4_scalar_quant(
9393
x_rescaled = q_val * scale_safe
9494
x_quant = tl.where(x >= 0, x_rescaled, -x_rescaled)
9595
return x_quant
96+
97+
98+
@triton.jit
99+
def fp8_quantize_scale(block_amax, global_scale):
100+
"""FP8 E4M3 fake-quantize the per-block NVFP4 scale.
101+
102+
Computes ``scale = block_amax / 6.0``, then round-trips it through
103+
FP8 E4M3 using ``global_scale`` for the second-level scaling.
104+
105+
Works with any tensor shape (scalar, 1-D, or higher) since all ops
106+
are element-wise.
107+
108+
Args:
109+
block_amax: Per-block amax value(s).
110+
global_scale: Pre-computed ``global_amax / (6.0 * 448.0)``.
111+
112+
Returns:
113+
FP8-quantized per-block scale(s), same shape as ``block_amax``.
114+
"""
115+
FP8_E4M3_MAX: tl.constexpr = 448.0
116+
scale_in_fp8_range = block_amax / (6.0 * global_scale)
117+
scale_clamped = tl.minimum(scale_in_fp8_range, FP8_E4M3_MAX)
118+
return scale_clamped.to(tl.float8e4nv).to(tl.float32) * global_scale
119+
120+
121+
@triton.jit
122+
def nvfp4_scalar_qdq(
123+
x, # [N] float32, already loaded
124+
block_amax, # float32 scalar: per-block amax
125+
global_scale, # float32 scalar: pre-computed global_amax / (6.0 * 448.0)
126+
N: tl.constexpr,
127+
):
128+
"""NVFP4 scalar fake quantization with inline two-level scale computation.
129+
130+
Computes the per-block FP8-quantized scale from ``block_amax`` via
131+
:func:`fp8_quantize_scale`, then quantizes each element to the nearest
132+
FP4 (E2M1) value.
133+
134+
Args:
135+
x: [N] float32 tensor of values to quantize.
136+
block_amax: Per-block amax (absolute maximum of the block).
137+
global_scale: Pre-computed ``global_amax / (6.0 * 448.0)``.
138+
N: Compile-time number of elements.
139+
140+
Returns:
141+
x_quant: [N] float32, fake-quantized values.
142+
"""
143+
scale = fp8_quantize_scale(block_amax, global_scale)
144+
return nvfp4_scalar_quant(x, scale, N)

modelopt/torch/quantization/utils/calib_utils.py

Lines changed: 32 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -126,10 +126,11 @@ class GPTQHelper:
126126

127127
CACHE_NAME = "_forward_no_gptq_hessian"
128128

129-
def __init__(self, module, name, offload_to_cpu=False):
129+
def __init__(self, module, name, offload_to_cpu=False, fused=False):
130130
"""Initialize GPTQHelper with module state and Hessian storage."""
131131
self.module = module
132132
self.name = name
133+
self.fused = fused
133134
in_features = module.weight.shape[-1]
134135
device = module.weight.device
135136
if device.type == "meta" or (offload_to_cpu and get_used_gpu_mem_fraction(device) > 0.65):
@@ -195,23 +196,35 @@ def _prepare_hessian_inverse(self, hessian, perc_damp):
195196
self.h_inv = compute_hessian_inverse(hessian, self.weight, perc_damp)
196197

197198
def _blockwise_update(self, block_size):
198-
"""Column-wise GPTQ update using full-matrix QDQ.
199+
"""Column-wise GPTQ update.
199200
200-
Delegates to :func:`gptq_blockwise_update` with the module's weight quantizer.
201+
When ``self.fused`` is True and the weight quantizer is an
202+
``NVFP4StaticQuantizer``, uses :func:`gptq_blockwise_update_fused_scalar`
203+
(a fused Triton kernel). Otherwise falls back to
204+
:func:`gptq_blockwise_update` (unfused column-by-column loop).
201205
"""
202206
assert self.weight is not None and self.h_inv is not None, (
203207
"_blockwise_update called before _prepare_hessian_inverse()"
204208
)
205209
quantizer = self.module.weight_quantizer
206-
block_sizes = getattr(quantizer, "block_sizes", None)
207-
if block_sizes is not None:
208-
group_size = block_sizes.get(-1)
209-
if group_size is not None and block_size % group_size != 0:
210+
211+
if self.fused and getattr(quantizer, "_is_nvfp4_static_quantizer", False):
212+
block_sizes = quantizer.block_sizes
213+
quant_block_size = block_sizes.get(-1) or block_sizes.get(1)
214+
if quant_block_size is not None and block_size % quant_block_size != 0:
210215
raise ValueError(
211216
f"GPTQ block_size ({block_size}) must be divisible by the quantizer"
212-
f" group_size ({group_size})"
217+
f" group_size ({quant_block_size})"
213218
)
214-
gptq_blockwise_update(self.weight, self.h_inv, block_size, quantizer)
219+
out_features, num_cols = self.weight.shape
220+
n_blocks = num_cols // quant_block_size
221+
block_amax = quantizer.amax.reshape(out_features, n_blocks).float()
222+
global_scale = quantizer.global_amax.float().item() / (6.0 * 448.0)
223+
gptq_blockwise_update_fused_scalar(
224+
self.weight, block_amax, global_scale, self.h_inv, block_size, quant_block_size
225+
)
226+
else:
227+
gptq_blockwise_update(self.weight, self.h_inv, block_size, quantizer)
215228

216229
def _print_mse_error(self, hessian):
217230
"""Log Hessian-weighted relative MSE between ``self.weight`` and original weights."""
@@ -260,17 +273,20 @@ def gptq_blockwise_update(weight, h_inv, block_size, quantize_fn):
260273
weight[:, block_end:].addmm_(errs, h_inv[block_start:block_end, block_end:], alpha=-1)
261274

262275

263-
def gptq_blockwise_update_fused_scalar(weight, scales_2d, h_inv, block_size, quant_block_size):
276+
def gptq_blockwise_update_fused_scalar(
277+
weight, block_amax, global_scale, h_inv, block_size, quant_block_size
278+
):
264279
"""Fused GPTQ blockwise update for NVFP4 scalar quantization.
265280
266-
Uses a fused Triton kernel that combines quantization and per-column
267-
error propagation into one launch per GPTQ block, avoiding the
268-
Python-level per-column loop in :func:`gptq_blockwise_update`.
281+
Uses a fused Triton kernel that combines scale computation, quantization,
282+
and per-column error propagation into one launch per GPTQ block, avoiding
283+
the Python-level per-column loop in :func:`gptq_blockwise_update`.
269284
270285
Args:
271286
weight: Weight tensor ``[out_features, in_features]``, modified **in-place**
272287
with fake-quantized values.
273-
scales_2d: Pre-computed per-block scales ``[out_features, n_scale_blocks]``.
288+
block_amax: Per-block amax values ``[out_features, n_amax_blocks]``.
289+
global_scale: Pre-computed ``global_amax / (6.0 * 448.0)`` (scalar).
274290
h_inv: Upper-triangular Cholesky factor of the damped inverse Hessian
275291
``[in_features, in_features]``.
276292
block_size: Number of columns to process per GPTQ block.
@@ -283,7 +299,8 @@ def gptq_blockwise_update_fused_scalar(weight, scales_2d, h_inv, block_size, qua
283299
be = min(bs + block_size, num_cols)
284300
qw, err = gptq_fused_block_scalar(
285301
weight[:, bs:be].clone().contiguous(),
286-
scales_2d,
302+
block_amax,
303+
global_scale,
287304
h_inv[bs:be, bs:be].contiguous(),
288305
quant_block_size,
289306
bs,
@@ -293,44 +310,6 @@ def gptq_blockwise_update_fused_scalar(weight, scales_2d, h_inv, block_size, qua
293310
weight[:, be:].addmm_(err, h_inv[bs:be, be:], alpha=-1)
294311

295312

296-
class FusedScalarGPTQHelper(GPTQHelper):
297-
"""GPTQHelper using the fused Triton kernel for NVFP4 scalar quantization.
298-
299-
Overrides :meth:`_blockwise_update` to extract pre-computed scales from the
300-
``NVFP4StaticQuantizer`` and delegate to :func:`gptq_blockwise_update_fused_scalar`.
301-
"""
302-
303-
def _blockwise_update(self, block_size):
304-
"""Fused GPTQ using Triton kernel for NVFP4 scalar quantization."""
305-
assert self.weight is not None and self.h_inv is not None, (
306-
"_blockwise_update called before _prepare_hessian_inverse()"
307-
)
308-
from modelopt.torch.quantization.triton.fp4_kernel import compute_fp4_scales
309-
310-
quantizer = self.module.weight_quantizer
311-
block_sizes = getattr(quantizer, "block_sizes", None)
312-
quant_block_size = None
313-
if block_sizes is not None:
314-
quant_block_size = block_sizes.get(-1) or block_sizes.get(1)
315-
316-
if quant_block_size is not None and block_size % quant_block_size != 0:
317-
raise ValueError(
318-
f"GPTQ block_size ({block_size}) must be divisible by the quantizer"
319-
f" group_size ({quant_block_size})"
320-
)
321-
322-
out_features, num_cols = self.weight.shape
323-
n_blocks = num_cols // quant_block_size
324-
325-
# Pre-compute scales from the calibrated amax (frozen during GPTQ).
326-
amax = quantizer.amax.reshape(out_features, n_blocks)
327-
scales_2d = compute_fp4_scales(amax, quantizer.global_amax, quantize_block_scales=True)
328-
329-
gptq_blockwise_update_fused_scalar(
330-
self.weight, scales_2d, self.h_inv, block_size, quant_block_size
331-
)
332-
333-
334313
_GPTQ_HELPER_REGISTRY: dict[str, type[GPTQHelper]] = {}
335314

336315

@@ -342,7 +321,3 @@ def register_gptq_helper(backend: str, factory: type[GPTQHelper]) -> None:
342321
construct ``factory`` instead of the default ``GPTQHelper``.
343322
"""
344323
_GPTQ_HELPER_REGISTRY[backend] = factory
345-
346-
347-
# Built-in registrations
348-
register_gptq_helper("fused_gptq_nvfp4", FusedScalarGPTQHelper)

0 commit comments

Comments (0)