|
| 1 | +# SPDX-License-Identifier: Apache-2.0 |
| 2 | +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project |
| 3 | +"""Stride-aware FP8 quantization with head_dim padding for ViT attention. |
| 4 | +
|
| 5 | +Reads directly from non-contiguous QKV views using 3D strides and pads |
| 6 | +head_dim to a multiple of 16 for cuDNN compatibility. |
| 7 | +""" |
| 8 | + |
| 9 | +import torch |
| 10 | + |
| 11 | +from aphrodite.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 |
| 12 | +from aphrodite.model_executor.layers.quantization.utils.quant_utils import ( |
| 13 | + get_fp8_min_max, |
| 14 | +) |
| 15 | +from aphrodite.platforms import current_platform |
| 16 | +from aphrodite.triton_utils import HAS_TRITON, tl, triton |
| 17 | +from aphrodite.utils.math_utils import round_up |
| 18 | + |
| 19 | +_FP8_MIN, _FP8_MAX = get_fp8_min_max() |
| 20 | + |
| 21 | + |
| 22 | +@triton.jit |
| 23 | +def _quantize_pad_fp8_kernel( |
| 24 | + x_ptr, |
| 25 | + y_ptr, |
| 26 | + scale_ptr, |
| 27 | + stride_xs, |
| 28 | + stride_xh, |
| 29 | + stride_xd, |
| 30 | + stride_ys, |
| 31 | + stride_yh, |
| 32 | + stride_yd, |
| 33 | + num_heads, |
| 34 | + n_rows, |
| 35 | + n_cols, |
| 36 | + n_cols_padded, |
| 37 | + fp8_min, |
| 38 | + fp8_max, |
| 39 | + SKIP_SCALE: tl.constexpr, |
| 40 | + BLOCK_M: tl.constexpr, |
| 41 | + BLOCK_N: tl.constexpr, |
| 42 | +): |
| 43 | + pid_m = tl.program_id(0) |
| 44 | + pid_n = tl.program_id(1) |
| 45 | + |
| 46 | + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) |
| 47 | + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) |
| 48 | + mask_m = offs_m < n_rows |
| 49 | + mask_out = mask_m[:, None] & (offs_n[None, :] < n_cols_padded) |
| 50 | + mask_in = mask_m[:, None] & (offs_n[None, :] < n_cols) |
| 51 | + |
| 52 | + # Decompose flattened row into (token, head) for 3D stride indexing. |
| 53 | + s = offs_m // num_heads |
| 54 | + h = offs_m % num_heads |
| 55 | + |
| 56 | + x_ptrs = x_ptr + s[:, None] * stride_xs + h[:, None] * stride_xh + offs_n[None, :] * stride_xd |
| 57 | + x = tl.load(x_ptrs, mask=mask_in, other=0.0).to(tl.float32) |
| 58 | + if SKIP_SCALE: |
| 59 | + x_q = x |
| 60 | + else: |
| 61 | + scale = tl.load(scale_ptr) |
| 62 | + x_q = x / scale |
| 63 | + x_q = tl.clamp(x_q, fp8_min, fp8_max).to(y_ptr.dtype.element_ty) |
| 64 | + |
| 65 | + y_ptrs = y_ptr + s[:, None] * stride_ys + h[:, None] * stride_yh + offs_n[None, :] * stride_yd |
| 66 | + tl.store(y_ptrs, x_q, mask=mask_out) |
| 67 | + |
| 68 | + |
| 69 | +def _get_fp8_pad_quant_config(padded_head_dim: int) -> tuple[int, int, int]: |
| 70 | + block_n = triton.next_power_of_2(padded_head_dim) |
| 71 | + block_n = max(16, min(block_n, 128)) |
| 72 | + block_m = 16 |
| 73 | + num_warps = 4 |
| 74 | + return block_m, block_n, num_warps |
| 75 | + |
| 76 | + |
| 77 | +def quantize_fp8_pad_head_dim_triton( |
| 78 | + tensor: torch.Tensor, |
| 79 | + scale: torch.Tensor, |
| 80 | + skip_scale: bool = False, |
| 81 | + block_m: int | None = None, |
| 82 | + block_n: int | None = None, |
| 83 | + num_warps: int | None = None, |
| 84 | +) -> torch.Tensor: |
| 85 | + """Quantize a 3D/4D tensor to FP8, padding head_dim to a multiple of 16. |
| 86 | +
|
| 87 | + Reads directly from the input using its 3D strides, so non-contiguous |
| 88 | + views (e.g. Q/K/V slices from an interleaved QKV buffer) are handled |
| 89 | + without an extra copy. Output is always a fresh contiguous tensor |
| 90 | + with shape (S, H, padded_D). |
| 91 | + """ |
| 92 | + if not HAS_TRITON: |
| 93 | + raise RuntimeError("Triton is required to quantize with head_dim padding.") |
| 94 | + |
| 95 | + original_shape = tensor.shape |
| 96 | + if tensor.dim() == 4: |
| 97 | + tensor = tensor.view(-1, tensor.shape[-2], tensor.shape[-1]) |
| 98 | + assert tensor.dim() == 3, f"Expected 3D input (S, H, D), got {tensor.dim()}D" |
| 99 | + S, H, D = tensor.shape |
| 100 | + padded_head_dim = round_up(D, 16) |
| 101 | + out_dtype = current_platform.fp8_dtype() |
| 102 | + output = torch.empty( |
| 103 | + (S, H, padded_head_dim), |
| 104 | + device=tensor.device, |
| 105 | + dtype=out_dtype, |
| 106 | + ) |
| 107 | + |
| 108 | + scale_1d = scale.reshape(-1) |
| 109 | + n_rows = S * H |
| 110 | + |
| 111 | + if block_m is None or block_n is None or num_warps is None: |
| 112 | + block_m, block_n, num_warps = _get_fp8_pad_quant_config(padded_head_dim) |
| 113 | + |
| 114 | + grid = ( |
| 115 | + triton.cdiv(n_rows, block_m), |
| 116 | + triton.cdiv(padded_head_dim, block_n), |
| 117 | + ) |
| 118 | + |
| 119 | + _quantize_pad_fp8_kernel[grid]( |
| 120 | + tensor, |
| 121 | + output, |
| 122 | + scale_1d, |
| 123 | + tensor.stride(0), |
| 124 | + tensor.stride(1), |
| 125 | + tensor.stride(2), |
| 126 | + output.stride(0), |
| 127 | + output.stride(1), |
| 128 | + output.stride(2), |
| 129 | + H, |
| 130 | + n_rows, |
| 131 | + D, |
| 132 | + padded_head_dim, |
| 133 | + _FP8_MIN, |
| 134 | + _FP8_MAX, |
| 135 | + SKIP_SCALE=skip_scale, |
| 136 | + BLOCK_M=block_m, |
| 137 | + BLOCK_N=block_n, |
| 138 | + num_warps=num_warps, |
| 139 | + ) |
| 140 | + |
| 141 | + return output.view((*original_shape[:-1], padded_head_dim)) |
| 142 | + |
| 143 | + |
| 144 | +def quantize_fp8_maybe_pad_head_dim( |
| 145 | + tensor: torch.Tensor, |
| 146 | + scale: torch.Tensor, |
| 147 | + fp8_quant: QuantFP8, |
| 148 | + skip_scale: bool = False, |
| 149 | +) -> torch.Tensor: |
| 150 | + """Quantize a 3D/4D tensor to FP8, padding head_dim to a multiple of 16 |
| 151 | + only when needed. |
| 152 | +
|
| 153 | + Accepts (S, H, D) or (B, S, H, D) input. Uses ``fp8_quant`` (a |
| 154 | + :class:`QuantFP8` CustomOp) when head_dim is already aligned to 16 |
| 155 | + (no padding); otherwise falls back to a stride-aware Triton kernel |
| 156 | + that pads head_dim to a multiple of 16. |
| 157 | + """ |
| 158 | + head_dim = tensor.shape[-1] |
| 159 | + if head_dim % 16 != 0: |
| 160 | + return quantize_fp8_pad_head_dim_triton(tensor, scale, skip_scale=skip_scale) |
| 161 | + |
| 162 | + if skip_scale: |
| 163 | + return tensor.to(current_platform.fp8_dtype()) |
| 164 | + |
| 165 | + # QuantFP8 expects 2D: flatten all dims except (H, D). |
| 166 | + orig_shape = tensor.shape |
| 167 | + total_tokens = tensor.numel() // (orig_shape[-1] * orig_shape[-2]) |
| 168 | + tensor_2d = tensor.reshape(total_tokens, -1) |
| 169 | + fp8_tensor, _ = fp8_quant(tensor_2d, scale=scale) |
| 170 | + return fp8_tensor.reshape(orig_shape) |
0 commit comments