|
| 1 | +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 2 | +# SPDX-License-Identifier: Apache-2.0 |
| 3 | +# |
| 4 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +# you may not use this file except in compliance with the License. |
| 6 | +# You may obtain a copy of the License at |
| 7 | +# |
| 8 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +# |
| 10 | +# Unless required by applicable law or agreed to in writing, software |
| 11 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +# See the License for the specific language governing permissions and |
| 14 | +# limitations under the License. |
| 15 | + |
| 16 | +"""Fused Triton kernels for GPTQ blockwise weight-update. |
| 17 | +
|
| 18 | +A kernel for scalar (NVFP4) quantization. |
| 19 | +Each kernel fuses quantization + per-column GPTQ error propagation into |
| 20 | +one launch per GPTQ block, avoiding the Python-level per-column loop. |
| 21 | +
|
| 22 | +Architecture (both kernels): |
| 23 | + - One Triton program per output row. |
| 24 | + - ``w_full [BLOCK_SIZE]`` register tensor holds working weights. |
| 25 | + - Per-column error propagation: ``w_full -= err * h_inv_row``. |
| 26 | +
|
| 27 | +Scalar kernel (``_gptq_scalar_kernel``): |
| 28 | + - Calls ``nvfp4_scalar_quant()`` from ``nvfp4_quant.py`` per column. |
| 29 | +""" |
| 30 | + |
| 31 | +import torch |
| 32 | +import triton |
| 33 | +import triton.language as tl |
| 34 | + |
| 35 | +from .nvfp4_quant import nvfp4_scalar_quant |
| 36 | + |
| 37 | +__all__ = ["gptq_fused_block_scalar"] |
| 38 | + |
| 39 | + |
| 40 | +# --------------------------------------------------------------------------- |
| 41 | +# Scalar kernel — NVFP4 quantization + error propagation |
| 42 | +# --------------------------------------------------------------------------- |
| 43 | + |
| 44 | + |
| 45 | +@triton.jit |
| 46 | +def _gptq_scalar_kernel( |
| 47 | + w_ptr, |
| 48 | + qw_ptr, |
| 49 | + err_ptr, |
| 50 | + scales_ptr, |
| 51 | + hinv_ptr, |
| 52 | + num_rows, |
| 53 | + n_scale_blocks, |
| 54 | + quant_block_size, |
| 55 | + block_start, |
| 56 | + n_cols, |
| 57 | + BLOCK_SIZE: tl.constexpr, |
| 58 | +): |
| 59 | + row = tl.program_id(0) |
| 60 | + if row >= num_rows: |
| 61 | + return |
| 62 | + |
| 63 | + w_base = w_ptr + row * BLOCK_SIZE |
| 64 | + qw_base = qw_ptr + row * BLOCK_SIZE |
| 65 | + err_base = err_ptr + row * BLOCK_SIZE |
| 66 | + scales_base = scales_ptr + row * n_scale_blocks |
| 67 | + |
| 68 | + j_range = tl.arange(0, BLOCK_SIZE) |
| 69 | + w_full = tl.load(w_base + j_range, mask=j_range < n_cols, other=0.0) |
| 70 | + |
| 71 | + for col in range(0, BLOCK_SIZE, 1): |
| 72 | + scale = tl.load(scales_base + (block_start + col) // quant_block_size) |
| 73 | + |
| 74 | + w_scalar = tl.sum(tl.where(j_range == col, w_full, 0.0)) |
| 75 | + q_scalar = tl.sum( |
| 76 | + nvfp4_scalar_quant( |
| 77 | + tl.full([1], w_scalar, dtype=tl.float32), |
| 78 | + scale, |
| 79 | + 1, |
| 80 | + ) |
| 81 | + ) |
| 82 | + |
| 83 | + d_val = tl.load(hinv_ptr + col * BLOCK_SIZE + col) |
| 84 | + err_val = (w_scalar - q_scalar) / d_val |
| 85 | + tl.store(err_base + col, err_val) |
| 86 | + tl.store(qw_base + col, q_scalar) |
| 87 | + |
| 88 | + remaining = (j_range > col) & (j_range < n_cols) |
| 89 | + hinv_row = tl.load(hinv_ptr + col * BLOCK_SIZE + j_range, mask=remaining, other=0.0) |
| 90 | + w_full = w_full - err_val * hinv_row |
| 91 | + |
| 92 | + |
| 93 | +def gptq_fused_block_scalar( |
| 94 | + w_block: torch.Tensor, |
| 95 | + scales_2d: torch.Tensor, |
| 96 | + h_inv_cho_blk: torch.Tensor, |
| 97 | + quant_block_size: int, |
| 98 | + block_start: int, |
| 99 | + n_cols: int, |
| 100 | +) -> tuple[torch.Tensor, torch.Tensor]: |
| 101 | + """Run scalar GPTQ (NVFP4) column loop for one block in a single Triton kernel launch. |
| 102 | +
|
| 103 | + Args: |
| 104 | + w_block: Working weights ``[num_rows, block_size]`` (float32). |
| 105 | + scales_2d: Pre-computed scales ``[num_rows, n_scale_blocks]`` (float32). |
| 106 | + h_inv_cho_blk: Block of upper-Cholesky inverse Hessian ``[block_size, block_size]``. |
| 107 | + quant_block_size: Number of elements sharing one scale factor. |
| 108 | + block_start: Column offset of this block in the full weight matrix. |
| 109 | + n_cols: Number of active columns in this block. |
| 110 | +
|
| 111 | + Returns: |
| 112 | + ``(qw_block, err_block)`` each ``[num_rows, block_size]``. |
| 113 | + """ |
| 114 | + num_rows, block_size = w_block.shape |
| 115 | + |
| 116 | + qw_block = torch.empty_like(w_block) |
| 117 | + err_block = torch.empty_like(w_block) |
| 118 | + |
| 119 | + _gptq_scalar_kernel[(num_rows,)]( |
| 120 | + w_block.contiguous(), |
| 121 | + qw_block, |
| 122 | + err_block, |
| 123 | + scales_2d.contiguous(), |
| 124 | + h_inv_cho_blk.contiguous(), |
| 125 | + num_rows, |
| 126 | + scales_2d.shape[1], |
| 127 | + quant_block_size, |
| 128 | + block_start, |
| 129 | + n_cols, |
| 130 | + BLOCK_SIZE=block_size, |
| 131 | + ) |
| 132 | + |
| 133 | + return qw_block, err_block |
0 commit comments