# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

| 16 | +"""Fused Triton kernel for the NVFP4 weight-MSE FP8 scale sweep. |
| 17 | +
|
| 18 | +Replaces the 126-iteration Python sweep in :class:`NVFP4MSECalibrator` with a single |
| 19 | +kernel that, for each NVFP4 block, evaluates all 126 valid FP8 E4M3 scale candidates |
| 20 | +and emits the per-block ``best_amax`` directly. |
| 21 | +
|
| 22 | +The 126 candidates are constructed as ``valid_fp8_e4m3_value / 448`` (see |
| 23 | +:func:`fp8_scale_candidates`). For these specific candidates, the FP8 round-trip on |
| 24 | +the per-block scale is the identity, so the kernel can use |
| 25 | +``scale = candidate * global_amax / 6.0`` without an explicit FP8 cast — making it |
| 26 | +runnable on any CUDA GPU with Triton (no ``tl.float8e4nv`` requirement). |
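
The exactness is easy to see from the construction: multiplying a candidate back by
448 lands (up to fp32 rounding) on a valid E4M3 value. A quick illustrative check,
assuming PyTorch's ``torch.float8_e4m3fn`` dtype is available::

    c = fp8_scale_candidates(torch.device("cuda")).float()
    assert torch.allclose((c * 448.0).to(torch.float8_e4m3fn).float(), c * 448.0)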

Tile shape (``BLOCKS_PER_PROGRAM``) and ``num_warps`` are autotuned per ``N_BLOCKS``.
"""

import torch
import triton
import triton.language as tl

from ._fp8_scale_candidates import fp8_scale_candidates
from .nvfp4_quant import fp4_round_magnitude

__all__ = ["fp8_scale_candidates", "nvfp4_fp8_scale_sweep"]

# Selected from a (BLOCKS_PER_PROGRAM, num_warps) sweep on B300:
#   BPP=16, nw=2: 6.06 ms   BPP=32, nw=4: 6.06 ms   BPP=64, nw=8: 5.08 ms
# The smaller-tile entries cover cases where N_BLOCKS is small enough that BPP=64
# would underfill the SMs.
_FP8_SWEEP_AUTOTUNE_CONFIGS = [
    triton.Config({"BLOCKS_PER_PROGRAM": 16}, num_warps=2),
    triton.Config({"BLOCKS_PER_PROGRAM": 32}, num_warps=4),
    triton.Config({"BLOCKS_PER_PROGRAM": 64}, num_warps=8),
]

@triton.autotune(configs=_FP8_SWEEP_AUTOTUNE_CONFIGS, key=["N_BLOCKS"])
@triton.jit
def _fp8_scale_sweep_kernel(
    x_ptr,  # [N_BLOCKS * BLOCK_SIZE], any float dtype (loaded as fp32)
    candidates_ptr,  # [NUM_CANDIDATES] fp32
    global_amax_ptr,  # scalar fp32
    best_amax_ptr,  # [N_BLOCKS] fp32 output
    N_BLOCKS,
    BLOCK_SIZE: tl.constexpr,
    NUM_CANDIDATES: tl.constexpr,
    BLOCKS_PER_PROGRAM: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCKS_PER_PROGRAM
    block_idx = block_start + tl.arange(0, BLOCKS_PER_PROGRAM)
    block_mask = block_idx < N_BLOCKS

    # Load weights for this tile and pre-compute their absolute values once.
    # The squared error is sign-invariant since FP4 quant preserves sign:
    #   (w - w_q)^2 = (|w| - |w_q|)^2 = (|w| - q_mag * scale)^2
    # so we never need ``w`` itself again, dropping a tl.where + negation per element.
    elem_offs = block_idx[:, None] * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)[None, :]
    elem_mask = block_mask[:, None]
    w_abs = tl.abs(tl.load(x_ptr + elem_offs, mask=elem_mask, other=0.0).to(tl.float32))

    global_amax = tl.load(global_amax_ptr).to(tl.float32)

    best_loss = tl.full([BLOCKS_PER_PROGRAM], float("inf"), dtype=tl.float32)
    best_idx = tl.zeros([BLOCKS_PER_PROGRAM], dtype=tl.int32)

    # Loop over the 126 FP8 candidates (compile-time unrolled).
    # Scales are guaranteed positive and finite (constructed from a positive candidate
    # times a nonnegative global_amax), so the degenerate-scale guard from
    # nvfp4_scalar_quant is unnecessary apart from the global_amax == 0 case handled below.
    for k in tl.static_range(NUM_CANDIDATES):
        c = tl.load(candidates_ptr + k).to(tl.float32)
        scale = c * global_amax / 6.0
        # Avoid divide-by-zero when global_amax == 0; in that case w_abs is also zero
        # (global_amax = max|w|), so the loss is zero for every candidate either way.
        scale_safe = tl.where(scale == 0.0, 1.0, scale)
        q_mag = fp4_round_magnitude(w_abs / scale_safe)
        diff = w_abs - q_mag * scale_safe
        loss = tl.sum(diff * diff, axis=1)  # [BLOCKS_PER_PROGRAM]
        is_better = loss < best_loss
        best_loss = tl.where(is_better, loss, best_loss)
        best_idx = tl.where(is_better, k, best_idx)

    # Map each block's winning candidate index back to its amax = global_amax * c[best].
    best_c = tl.load(candidates_ptr + best_idx, mask=block_mask, other=0.0).to(tl.float32)
    best_amax = global_amax * best_c
    tl.store(best_amax_ptr + block_idx, best_amax, mask=block_mask)

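# For reference, the kernel above computes the same result as this unfused
# PyTorch-style sketch (illustrative only; assumes a tensor-level analogue of
# ``fp4_round_magnitude`` that rounds fp32 magnitudes onto the FP4 grid):
#
#   w_abs = x.view(n_blocks, block_size).abs().float()                # [B, S]
#   scales = candidates * global_amax / 6.0                           # [C] = [126]
#   q_mag = fp4_round_magnitude(w_abs[:, None, :] / scales[None, :, None])
#   loss = ((w_abs[:, None, :] - q_mag * scales[None, :, None]) ** 2).sum(-1)
#   best_amax = global_amax * candidates[loss.argmin(dim=1)]          # [B]

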
def nvfp4_fp8_scale_sweep(
    x: torch.Tensor,
    global_amax: torch.Tensor,
    block_size: int = 16,
) -> torch.Tensor:
    """Find the per-block FP8 scale that minimizes NVFP4 quantization MSE.

    Equivalent to the 126-step sweep in :class:`NVFP4MSECalibrator`, but fused into
    a single Triton kernel: every block's weight elements are loaded once, all 126
    candidates are evaluated in registers, and the running argmin is kept inline.

    Args:
        x: Weight tensor on CUDA. Total element count must be divisible by
            ``block_size``; layout is treated as a flat ``[N_BLOCKS, BLOCK_SIZE]``.
        global_amax: Scalar FP32 global amax (``= reduce_amax(per_block_amax)``).
        block_size: NVFP4 block size (typically 16).

    Returns:
        ``best_amax`` of shape ``[N_BLOCKS]``, fp32, on the same device as ``x``.
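
    Example (a minimal sketch; shapes are illustrative)::

        w = torch.randn(4096, 4096, device="cuda")
        best_amax = nvfp4_fp8_scale_sweep(w, w.abs().amax().float())
        assert best_amax.shape == (4096 * 4096 // 16,)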
    """
    if not x.is_cuda:
        raise ValueError("nvfp4_fp8_scale_sweep requires a CUDA tensor.")
    if not isinstance(block_size, int) or block_size <= 0:
        raise ValueError(f"block_size must be a positive int, got {block_size!r}.")
    if x.numel() % block_size != 0:
        raise ValueError(f"x.numel() ({x.numel()}) is not divisible by block_size ({block_size}).")

    candidates = fp8_scale_candidates(x.device).to(dtype=torch.float32)

    n_blocks = x.numel() // block_size
    x_flat = x.contiguous().view(-1)
    global_amax_f32 = global_amax.detach().to(device=x.device, dtype=torch.float32).reshape(1)
    best_amax = torch.empty(n_blocks, dtype=torch.float32, device=x.device)

    grid = lambda meta: (triton.cdiv(n_blocks, meta["BLOCKS_PER_PROGRAM"]),)
    with torch.cuda.device(x.device):
        _fp8_scale_sweep_kernel[grid](
            x_flat,
            candidates,
            global_amax_f32,
            best_amax,
            n_blocks,
            BLOCK_SIZE=block_size,
            NUM_CANDIDATES=int(candidates.numel()),
        )
    return best_amax