
Commit fc14872

add gptq fused kernel

Signed-off-by: Shiyang Chen <shiychen@nvidia.com>
1 parent e4b054b

7 files changed: 500 additions & 79 deletions


modelopt/torch/quantization/triton/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -33,6 +33,7 @@
 # fp4_kernel works on any CUDA GPU with triton
 from .fp4_kernel import *
 from .fp8_kernel import *
+from .nvfp4_quant import *
 
 # fp4_kernel_hopper requires compute >= 8.9 (uses tl.float8e4nv)
 if torch.cuda.get_device_capability() >= (8, 9):

modelopt/torch/quantization/triton/fp4_kernel.py

Lines changed: 39 additions & 54 deletions
@@ -24,7 +24,9 @@
 import triton
 import triton.language as tl
 
-__all__ = ["fp4_dequantize", "static_blockwise_fp4_fake_quant"]
+from .nvfp4_quant import nvfp4_scalar_quant
+
+__all__ = ["compute_fp4_scales", "fp4_dequantize", "static_blockwise_fp4_fake_quant"]
 
 
 _TORCH_TO_TL_DTYPE = {
@@ -198,52 +200,47 @@ def static_blockwise_fp4_fake_quant_kernel(
     idx = block_offset + tl.arange(0, BLOCK_SIZE)
 
     scale = tl.load(scale_ptr + pid).to(tl.float32)
-
     x = tl.load(x_ptr + idx).to(tl.float32)
 
-    x_abs = tl.abs(x)
-    # If scale is 0, inf, or nan, use 1.0 (matching CUDA kernel behavior)
-    # Note: (x != x) checks if x is NaN per IEEE 754
-    scale_safe = tl.where(
-        (scale == 0) | (scale != scale) | (tl.abs(scale) == float("inf")),  # noqa: PLR0124
-        1.0,
-        scale,
-    )
-    abs_scaled = x_abs / scale_safe
-
-    # FP4 values: 0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0
-    q_val = tl.where(
-        abs_scaled <= 0.25,
-        0.0,
-        tl.where(
-            abs_scaled < 0.75,
-            0.5,
-            tl.where(
-                abs_scaled <= 1.25,
-                1.0,
-                tl.where(
-                    abs_scaled < 1.75,
-                    1.5,
-                    tl.where(
-                        abs_scaled <= 2.5,
-                        2.0,
-                        tl.where(
-                            abs_scaled < 3.5,
-                            3.0,
-                            tl.where(abs_scaled <= 5.0, 4.0, 6.0),
-                        ),
-                    ),
-                ),
-            ),
-        ),
-    )
-
-    x_rescaled = q_val * scale_safe
-    x_quant = tl.where(x >= 0, x_rescaled, -x_rescaled)
+    x_quant = nvfp4_scalar_quant(x, scale, BLOCK_SIZE)
 
     tl.store(y_ptr + idx, x_quant.to(OUT_DTYPE))
 
 
+def compute_fp4_scales(
+    amax: torch.Tensor,
+    global_amax: torch.Tensor | None = None,
+    quantize_block_scales: bool = True,
+) -> torch.Tensor:
+    """Compute per-block FP4 scales from amax values.
+
+    ``scale = amax / 6.0``, optionally quantized to FP8 E4M3.
+
+    Args:
+        amax: Per-block amax values (any shape).
+        global_amax: Global amax for FP8 two-level scaling. Computed from *amax* if None.
+        quantize_block_scales: If True, quantize scales to FP8 E4M3.
+
+    Returns:
+        Per-block scales (same shape as *amax*), float32.
+    """
+    amax = amax.float()
+    scale = amax / 6.0  # FP4 max representable value is 6.0
+
+    if quantize_block_scales:
+        from modelopt.torch.quantization.tensor_quant import scaled_e4m3_impl
+        from modelopt.torch.quantization.utils import reduce_amax
+
+        if global_amax is None:
+            global_amax = reduce_amax(amax, axis=None, keepdims=False, squeeze_scalar=True)
+
+        global_amax = global_amax.float()
+        scale_fp8_quant_amax = global_amax / 6.0
+        scale = scaled_e4m3_impl(scale, scale_fp8_quant_amax)
+
+    return scale
+
+
 def static_blockwise_fp4_fake_quant(
     x: torch.Tensor,
     amax: torch.Tensor,
@@ -266,19 +263,7 @@ def static_blockwise_fp4_fake_quant(
     if out_dtype is None:
         out_dtype = x.dtype
 
-    amax = amax.float()  # Requires to be in float32
-    scale = amax / 6.0  # FP4 max representable value is 6.0
-
-    if quantize_block_scales:
-        from modelopt.torch.quantization.tensor_quant import scaled_e4m3_impl
-        from modelopt.torch.quantization.utils import reduce_amax
-
-        if global_amax is None:
-            global_amax = reduce_amax(amax, axis=None, keepdims=False, squeeze_scalar=True)
-
-        global_amax = global_amax.float()
-        scale_fp8_quant_amax = global_amax / 6.0
-        scale = scaled_e4m3_impl(scale, scale_fp8_quant_amax)
+    scale = compute_fp4_scales(amax, global_amax, quantize_block_scales)
 
     x_flat = x.contiguous().view(-1)
     y_flat = torch.empty_like(x_flat, dtype=out_dtype)
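With the scale computation factored out as compute_fp4_scales, callers can build NVFP4 scales without launching the Triton kernel. A minimal usage sketch (the 4096-element weight, the block size of 16, and the amax reduction below are illustrative assumptions, not part of this commit):

import torch

from modelopt.torch.quantization.triton.fp4_kernel import compute_fp4_scales

# Hypothetical setup: a flat weight vector quantized in blocks of 16 elements.
block_size = 16
w = torch.randn(4096, device="cuda")  # assumes a CUDA device is available

# Per-block amax: reshape to [num_blocks, block_size], reduce over the block axis.
amax = w.view(-1, block_size).abs().amax(dim=-1)

# scale = amax / 6.0, optionally quantized to FP8 E4M3 against a global amax
# (NVFP4 two-level scaling); per the docstring the result is float32.
scales = compute_fp4_scales(amax, global_amax=None, quantize_block_scales=True)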

modelopt/torch/quantization/triton/fp4_kernel_hopper.py

Lines changed: 2 additions & 25 deletions
@@ -24,6 +24,7 @@
 import triton.language as tl
 
 from .fp4_kernel import _torch_dtype_to_tl
+from .nvfp4_quant import fp4_round_magnitude
 
 __all__ = ["fp4_fake_quant_block"]
 
@@ -90,31 +91,7 @@ def fp4_fake_quant_kernel(
 
     abs_scaled = x_abs / block_max_quant_broadcast
 
-    q_val = tl.where(
-        abs_scaled <= 0.25,
-        0.0,
-        tl.where(
-            abs_scaled < 0.75,
-            0.5,
-            tl.where(
-                abs_scaled <= 1.25,
-                1.0,
-                tl.where(
-                    abs_scaled < 1.75,
-                    1.5,
-                    tl.where(
-                        abs_scaled <= 2.5,
-                        2.0,
-                        tl.where(
-                            abs_scaled < 3.5,
-                            3.0,
-                            tl.where(abs_scaled <= 5.0, 4.0, 6.0),
-                        ),
-                    ),
-                ),
-            ),
-        ),
-    )
+    q_val = fp4_round_magnitude(abs_scaled)
 
     x_rescaled = q_val * block_max_quant_broadcast
     x_rescaled = tl.where(tile_reshaped >= 0, x_rescaled, -x_rescaled)
modelopt/torch/quantization/triton/gptq_fused_kernel.py

Lines changed: 133 additions & 0 deletions
@@ -0,0 +1,133 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Fused Triton kernel for GPTQ blockwise weight update.
+
+Provides a kernel for scalar (NVFP4) quantization. The kernel fuses
+quantization + per-column GPTQ error propagation into one launch per
+GPTQ block, avoiding the Python-level per-column loop.
+
+Architecture:
+    - One Triton program per output row.
+    - ``w_full [BLOCK_SIZE]`` register tensor holds the working weights.
+    - Per-column error propagation: ``w_full -= err * h_inv_row``.
+
+Scalar kernel (``_gptq_scalar_kernel``):
+    - Calls ``nvfp4_scalar_quant()`` from ``nvfp4_quant.py`` per column.
+"""
+
+import torch
+import triton
+import triton.language as tl
+
+from .nvfp4_quant import nvfp4_scalar_quant
+
+__all__ = ["gptq_fused_block_scalar"]
+
+
+# ---------------------------------------------------------------------------
+# Scalar kernel — NVFP4 quantization + error propagation
+# ---------------------------------------------------------------------------
+
+
+@triton.jit
+def _gptq_scalar_kernel(
+    w_ptr,
+    qw_ptr,
+    err_ptr,
+    scales_ptr,
+    hinv_ptr,
+    num_rows,
+    n_scale_blocks,
+    quant_block_size,
+    block_start,
+    n_cols,
+    BLOCK_SIZE: tl.constexpr,
+):
+    row = tl.program_id(0)
+    if row >= num_rows:
+        return
+
+    w_base = w_ptr + row * BLOCK_SIZE
+    qw_base = qw_ptr + row * BLOCK_SIZE
+    err_base = err_ptr + row * BLOCK_SIZE
+    scales_base = scales_ptr + row * n_scale_blocks
+
+    j_range = tl.arange(0, BLOCK_SIZE)
+    w_full = tl.load(w_base + j_range, mask=j_range < n_cols, other=0.0)
+
+    for col in range(0, BLOCK_SIZE, 1):
+        scale = tl.load(scales_base + (block_start + col) // quant_block_size)
+
+        w_scalar = tl.sum(tl.where(j_range == col, w_full, 0.0))
+        q_scalar = tl.sum(
+            nvfp4_scalar_quant(
+                tl.full([1], w_scalar, dtype=tl.float32),
+                scale,
+                1,
+            )
+        )
+
+        d_val = tl.load(hinv_ptr + col * BLOCK_SIZE + col)
+        err_val = (w_scalar - q_scalar) / d_val
+        tl.store(err_base + col, err_val)
+        tl.store(qw_base + col, q_scalar)
+
+        remaining = (j_range > col) & (j_range < n_cols)
+        hinv_row = tl.load(hinv_ptr + col * BLOCK_SIZE + j_range, mask=remaining, other=0.0)
+        w_full = w_full - err_val * hinv_row
+
+
+def gptq_fused_block_scalar(
+    w_block: torch.Tensor,
+    scales_2d: torch.Tensor,
+    h_inv_cho_blk: torch.Tensor,
+    quant_block_size: int,
+    block_start: int,
+    n_cols: int,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """Run the scalar GPTQ (NVFP4) column loop for one block in a single Triton kernel launch.
+
+    Args:
+        w_block: Working weights ``[num_rows, block_size]`` (float32).
+        scales_2d: Pre-computed scales ``[num_rows, n_scale_blocks]`` (float32).
+        h_inv_cho_blk: Block of the upper-Cholesky inverse Hessian ``[block_size, block_size]``.
+        quant_block_size: Number of elements sharing one scale factor.
+        block_start: Column offset of this block in the full weight matrix.
+        n_cols: Number of active columns in this block.
+
+    Returns:
+        ``(qw_block, err_block)``, each ``[num_rows, block_size]``.
+    """
+    num_rows, block_size = w_block.shape
+
+    qw_block = torch.empty_like(w_block)
+    err_block = torch.empty_like(w_block)
+
+    _gptq_scalar_kernel[(num_rows,)](
+        w_block.contiguous(),
+        qw_block,
+        err_block,
+        scales_2d.contiguous(),
+        h_inv_cho_blk.contiguous(),
+        num_rows,
+        scales_2d.shape[1],
+        quant_block_size,
+        block_start,
+        n_cols,
+        BLOCK_SIZE=block_size,
+    )
+
+    return qw_block, err_block
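For orientation, below is an unfused PyTorch sketch of the per-column loop that gptq_fused_block_scalar collapses into one kernel launch, reconstructed from the docstrings and kernel body above. The helper fake_quant_nvfp4 is a hypothetical host-side stand-in for nvfp4_scalar_quant; it uses plain nearest-value rounding, so behavior at exact midpoints may differ from the Triton ladder's ties-to-even boundaries.

import torch

_FP4_VALS = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])


def fake_quant_nvfp4(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in: divide by scale, snap |x| to the nearest E2M1
    # magnitude, restore sign and scale. Degenerate scales fall back to 1.0.
    s = torch.where((scale == 0) | ~torch.isfinite(scale), torch.ones_like(scale), scale)
    vals = _FP4_VALS.to(x.device)
    a = (x.abs() / s).unsqueeze(-1)
    q = vals[(a - vals).abs().argmin(dim=-1)]
    return torch.copysign(q * s, x)


def gptq_block_reference(w_block, scales_2d, h_inv_cho_blk,
                         quant_block_size, block_start, n_cols):
    # One Python iteration per column: the loop the fused kernel replaces.
    qw = torch.zeros_like(w_block)
    err = torch.zeros_like(w_block)
    w = w_block.clone()
    for col in range(n_cols):
        scale = scales_2d[:, (block_start + col) // quant_block_size]
        q = fake_quant_nvfp4(w[:, col], scale)
        d = h_inv_cho_blk[col, col]
        e = (w[:, col] - q) / d
        qw[:, col] = q
        err[:, col] = e
        # Propagate the error into the not-yet-quantized columns of this block.
        w[:, col + 1 : n_cols] -= e[:, None] * h_inv_cho_blk[col, col + 1 : n_cols]
    return qw, err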
modelopt/torch/quantization/triton/nvfp4_quant.py

Lines changed: 95 additions & 0 deletions
@@ -0,0 +1,95 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Composable Triton JIT functions for NVFP4 (E2M1) fake quantization.
+
+Single source of truth for FP4 decision-boundary rounding. Used by:
+    - ``fp4_kernel.py`` (standalone blockwise fake quant)
+    - ``fp4_kernel_hopper.py`` (Hopper block-pointer variant)
+    - ``gptq_fused_kernel.py`` (fused GPTQ scalar path)
+
+FP4 (E2M1) representable magnitudes: {0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0}
+"""
+
+import triton
+import triton.language as tl
+from triton.language.extra.cuda import libdevice
+
+
+@triton.jit
+def fp4_round_magnitude(abs_scaled):
+    """Round ``|x| / scale`` to the nearest FP4 (E2M1) magnitude.
+
+    Works with any tensor shape — the caller is responsible for computing
+    ``abs_scaled = |x| / scale`` beforehand.
+
+    Returns:
+        Tensor of the same shape as *abs_scaled* with values in
+        {0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0}.
+    """
+    return tl.where(
+        abs_scaled <= 0.25,
+        0.0,
+        tl.where(
+            abs_scaled < 0.75,
+            0.5,
+            tl.where(
+                abs_scaled <= 1.25,
+                1.0,
+                tl.where(
+                    abs_scaled < 1.75,
+                    1.5,
+                    tl.where(
+                        abs_scaled <= 2.5,
+                        2.0,
+                        tl.where(abs_scaled < 3.5, 3.0, tl.where(abs_scaled <= 5.0, 4.0, 6.0)),
+                    ),
+                ),
+            ),
+        ),
+    )
+
+
+@triton.jit
+def nvfp4_scalar_quant(
+    x,  # [N] float32, already loaded
+    scale,  # float32 scalar: pre-computed block scale (amax / 6.0)
+    N: tl.constexpr,
+):
+    """NVFP4 scalar fake quantization for a group of elements sharing one scale.
+
+    Quantizes each element independently: divide by scale, round to the nearest
+    FP4 (E2M1) value via ``fp4_round_magnitude``, multiply by scale.
+
+    Args:
+        x: [N] float32 tensor of values to quantize (already in registers).
+        scale: float32 scalar block scale.
+        N: Compile-time number of elements.
+
+    Returns:
+        x_quant: [N] float32, fake-quantized values.
+    """
+    x_abs = tl.abs(x)
+    # Guard against degenerate scale (matching CUDA kernel behavior)
+    scale_safe = tl.where(
+        (scale == 0.0) | libdevice.isnan(scale) | (tl.abs(scale) == float("inf")),
+        1.0,
+        scale,
+    )
+    abs_scaled = x_abs / scale_safe
+    q_val = fp4_round_magnitude(abs_scaled)
+    x_rescaled = q_val * scale_safe
+    x_quant = tl.where(x >= 0, x_rescaled, -x_rescaled)
+    return x_quant
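Every boundary in fp4_round_magnitude sits at the midpoint between two adjacent E2M1 magnitudes, and each midpoint resolves to the magnitude with the even quantization code (0.25 -> 0.0, 0.75 -> 1.0, 2.5 -> 2.0, ...), i.e. round-half-to-even. A quick host-side check with a pure-PyTorch mirror of the ladder (an illustrative sketch, not code from this commit):

import torch


def fp4_round_magnitude_ref(a: torch.Tensor) -> torch.Tensor:
    # Mirror of the Triton ladder: assign from the largest bucket down so the
    # tightest condition wins, reproducing the <= / < boundary pattern above.
    out = torch.full_like(a, 6.0)
    out[a <= 5.0] = 4.0
    out[a < 3.5] = 3.0
    out[a <= 2.5] = 2.0
    out[a < 1.75] = 1.5
    out[a <= 1.25] = 1.0
    out[a < 0.75] = 0.5
    out[a <= 0.25] = 0.0
    return out


mids = torch.tensor([0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0])
print(fp4_round_magnitude_ref(mids))
# tensor([0., 1., 1., 2., 2., 4., 4.])  -> every midpoint lands on an even E2M1 code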

0 commit comments