Skip to content

Commit 1241957

Browse files
authored
feat: FP8 ViT Attention w/ FlashInfer (#1660)
Signed-off-by: AlpinDale <alpindale@gmail.com>
1 parent 42769da commit 1241957

11 files changed

Lines changed: 684 additions & 41 deletions

File tree

aphrodite/config/model.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,10 @@ class ModelConfig:
323323
mm_encoder_only: InitVar[bool | None] = None # type: ignore[assignment]
324324
mm_encoder_tp_mode: InitVar[MMEncoderTPMode | None] = None # type: ignore[assignment]
325325
mm_encoder_attn_backend: InitVar[AttentionBackendEnum | str | None] = None # type: ignore[assignment]
326+
mm_encoder_attn_dtype: InitVar[str | None] = None # type: ignore[assignment]
327+
mm_encoder_fp8_scale_path: InitVar[str | None] = None # type: ignore[assignment]
328+
mm_encoder_fp8_scale_save_path: InitVar[str | None] = None # type: ignore[assignment]
329+
mm_encoder_fp8_scale_save_margin: InitVar[float | None] = None # type: ignore[assignment]
326330
interleave_mm_strings: InitVar[bool | None] = None # type: ignore[assignment]
327331
skip_mm_profiling: InitVar[bool | None] = None # type: ignore[assignment]
328332
video_pruning_rate: InitVar[float | None] = None # type: ignore[assignment]
@@ -443,6 +447,10 @@ def __post_init__(
443447
mm_encoder_only: bool | None,
444448
mm_encoder_tp_mode: MMEncoderTPMode | None,
445449
mm_encoder_attn_backend: AttentionBackendEnum | str | None,
450+
mm_encoder_attn_dtype: str | None,
451+
mm_encoder_fp8_scale_path: str | None,
452+
mm_encoder_fp8_scale_save_path: str | None,
453+
mm_encoder_fp8_scale_save_margin: float | None,
446454
interleave_mm_strings: bool | None,
447455
skip_mm_profiling: bool | None,
448456
video_pruning_rate: float | None,
@@ -631,6 +639,10 @@ def __post_init__(
631639
mm_encoder_only=mm_encoder_only,
632640
mm_encoder_tp_mode=mm_encoder_tp_mode,
633641
mm_encoder_attn_backend=mm_encoder_attn_backend,
642+
mm_encoder_attn_dtype=mm_encoder_attn_dtype,
643+
mm_encoder_fp8_scale_path=mm_encoder_fp8_scale_path,
644+
mm_encoder_fp8_scale_save_path=mm_encoder_fp8_scale_save_path,
645+
mm_encoder_fp8_scale_save_margin=mm_encoder_fp8_scale_save_margin,
634646
interleave_mm_strings=interleave_mm_strings,
635647
skip_mm_profiling=skip_mm_profiling,
636648
video_pruning_rate=video_pruning_rate,

aphrodite/config/multimodal.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33

44
from collections.abc import Mapping
5+
from pathlib import Path
56
from typing import Any, Literal, TypeAlias, TypedDict, final
67

78
from pydantic import ConfigDict, Field, field_validator, model_validator
@@ -158,6 +159,24 @@ class MultiModalConfig:
158159
"""Optional override for the multi-modal encoder attention backend when
159160
using vision transformers. Accepts any value from
160161
`aphrodite.v1.attention.backends.registry.AttentionBackendEnum` (e.g. `FLASH_ATTN`)."""
162+
mm_encoder_attn_dtype: Literal["fp8"] | None = None
163+
"""Optional dtype override for ViT encoder attention. Set to `"fp8"` to
164+
enable FP8 quantization via the FlashInfer cuDNN backend. When set to
165+
`"fp8"` without a scale file, dynamic scaling is used automatically.
166+
See docs/features/quantization/fp8_vit_attn.md for details."""
167+
mm_encoder_fp8_scale_path: str | None = None
168+
"""Path to a JSON file containing per-layer FP8 Q/K/V scales for ViT
169+
encoder attention. When provided (with `mm_encoder_attn_dtype="fp8"`),
170+
static scaling is used. When omitted, dynamic scaling is used."""
171+
mm_encoder_fp8_scale_save_path: str | None = None
172+
"""When set with dynamic FP8 scaling (`mm_encoder_attn_dtype="fp8"`
173+
and no `mm_encoder_fp8_scale_path`), saves the calibrated scales to
174+
this file after the amax history buffer is full. The saved file can
175+
then be used as `mm_encoder_fp8_scale_path` in subsequent runs."""
176+
mm_encoder_fp8_scale_save_margin: float = Field(default=1.5, gt=0.0)
177+
"""Safety margin multiplied onto scales when auto-saving. A value > 1
178+
leaves headroom so that inputs with larger activations than the
179+
calibration set do not overflow FP8 range. Default 1.5."""
161180
interleave_mm_strings: bool = False
162181
"""Enable fully interleaved support for multimodal prompts, while using
163182
--chat-template-content-format=string."""
@@ -227,6 +246,30 @@ def _validate_multimodal_config(self):
227246
raise ValueError(
228247
"'mm_shm_cache_max_object_size_mb' should only be set when 'mm_processor_cache_type' is 'shm'."
229248
)
249+
# Validate FP8 scale path combinations.
250+
if self.mm_encoder_attn_dtype != "fp8" and (
251+
self.mm_encoder_fp8_scale_path is not None or self.mm_encoder_fp8_scale_save_path is not None
252+
):
253+
raise ValueError(
254+
"'mm_encoder_fp8_scale_path' and "
255+
"'mm_encoder_fp8_scale_save_path' require "
256+
"'mm_encoder_attn_dtype' to be 'fp8'."
257+
)
258+
if self.mm_encoder_fp8_scale_path is not None and self.mm_encoder_fp8_scale_save_path is not None:
259+
raise ValueError(
260+
"'mm_encoder_fp8_scale_save_path' cannot be used with "
261+
"'mm_encoder_fp8_scale_path' (saving requires dynamic scaling)."
262+
)
263+
264+
# Validate file paths exist.
265+
if self.mm_encoder_fp8_scale_path is not None:
266+
scale_path = Path(self.mm_encoder_fp8_scale_path)
267+
if not scale_path.is_file():
268+
raise FileNotFoundError(f"FP8 scale file not found: {scale_path}")
269+
if self.mm_encoder_fp8_scale_save_path is not None:
270+
save_parent = Path(self.mm_encoder_fp8_scale_save_path).parent
271+
if not save_parent.is_dir():
272+
raise FileNotFoundError(f"Parent directory for FP8 scale save path not found: {save_parent}")
230273
return self
231274

232275
def compute_hash(self) -> str:
@@ -244,6 +287,8 @@ def compute_hash(self) -> str:
244287
factors: list[Any] = [
245288
self.mm_encoder_attn_backend.name if self.mm_encoder_attn_backend is not None else None,
246289
self.mm_encoder_tp_mode,
290+
self.mm_encoder_attn_dtype,
291+
self.mm_encoder_fp8_scale_path,
247292
]
248293
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
249294
return hash_str

aphrodite/engine/arg_utils.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -508,6 +508,10 @@ class EngineArgs:
508508
mm_encoder_only: bool = MultiModalConfig.mm_encoder_only
509509
mm_encoder_tp_mode: MMEncoderTPMode = MultiModalConfig.mm_encoder_tp_mode
510510
mm_encoder_attn_backend: AttentionBackendEnum | str | None = MultiModalConfig.mm_encoder_attn_backend
511+
mm_encoder_attn_dtype: str | None = MultiModalConfig.mm_encoder_attn_dtype
512+
mm_encoder_fp8_scale_path: str | None = MultiModalConfig.mm_encoder_fp8_scale_path
513+
mm_encoder_fp8_scale_save_path: str | None = MultiModalConfig.mm_encoder_fp8_scale_save_path
514+
mm_encoder_fp8_scale_save_margin: float = MultiModalConfig.mm_encoder_fp8_scale_save_margin
511515
io_processor_plugin: str | None = None
512516
renderer_num_workers: int = 1
513517
skip_mm_profiling: bool = MultiModalConfig.skip_mm_profiling
@@ -1015,6 +1019,22 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
10151019
"--mm-encoder-attn-backend",
10161020
**multimodal_kwargs["mm_encoder_attn_backend"],
10171021
)
1022+
multimodal_group.add_argument(
1023+
"--mm-encoder-attn-dtype",
1024+
**multimodal_kwargs["mm_encoder_attn_dtype"],
1025+
)
1026+
multimodal_group.add_argument(
1027+
"--mm-encoder-fp8-scale-path",
1028+
**multimodal_kwargs["mm_encoder_fp8_scale_path"],
1029+
)
1030+
multimodal_group.add_argument(
1031+
"--mm-encoder-fp8-scale-save-path",
1032+
**multimodal_kwargs["mm_encoder_fp8_scale_save_path"],
1033+
)
1034+
multimodal_group.add_argument(
1035+
"--mm-encoder-fp8-scale-save-margin",
1036+
**multimodal_kwargs["mm_encoder_fp8_scale_save_margin"],
1037+
)
10181038
multimodal_group.add_argument("--interleave-mm-strings", **multimodal_kwargs["interleave_mm_strings"])
10191039
multimodal_group.add_argument("--skip-mm-profiling", **multimodal_kwargs["skip_mm_profiling"])
10201040

@@ -1302,6 +1322,10 @@ def create_model_config(self) -> ModelConfig:
13021322
mm_encoder_only=self.mm_encoder_only,
13031323
mm_encoder_tp_mode=self.mm_encoder_tp_mode,
13041324
mm_encoder_attn_backend=self.mm_encoder_attn_backend,
1325+
mm_encoder_attn_dtype=self.mm_encoder_attn_dtype,
1326+
mm_encoder_fp8_scale_path=self.mm_encoder_fp8_scale_path,
1327+
mm_encoder_fp8_scale_save_path=self.mm_encoder_fp8_scale_save_path,
1328+
mm_encoder_fp8_scale_save_margin=self.mm_encoder_fp8_scale_save_margin,
13051329
pooler_config=self.pooler_config,
13061330
generation_config=self.generation_config,
13071331
override_generation_config=self.override_generation_config,
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
"""Triton kernel implementations."""
Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
"""Stride-aware FP8 quantization with head_dim padding for ViT attention.
4+
5+
Reads directly from non-contiguous QKV views using 3D strides and pads
6+
head_dim to a multiple of 16 for cuDNN compatibility.
7+
"""
8+
9+
import torch
10+
11+
from aphrodite.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
12+
from aphrodite.model_executor.layers.quantization.utils.quant_utils import (
13+
get_fp8_min_max,
14+
)
15+
from aphrodite.platforms import current_platform
16+
from aphrodite.triton_utils import HAS_TRITON, tl, triton
17+
from aphrodite.utils.math_utils import round_up
18+
19+
# FP8 value range as reported by get_fp8_min_max(); passed to the Triton
# quantization kernel in this module as its clamp bounds.
_FP8_MIN, _FP8_MAX = get_fp8_min_max()
20+
21+
22+
@triton.jit
def _quantize_pad_fp8_kernel(
    x_ptr,
    y_ptr,
    scale_ptr,
    stride_xs,
    stride_xh,
    stride_xd,
    stride_ys,
    stride_yh,
    stride_yd,
    num_heads,
    n_rows,
    n_cols,
    n_cols_padded,
    fp8_min,
    fp8_max,
    SKIP_SCALE: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Quantize a strided (token, head, dim) tensor and pad its last dim.

    Each program handles a BLOCK_M x BLOCK_N tile, where rows are the
    flattened (token, head) pairs and columns are head_dim entries.
    Columns in [n_cols, n_cols_padded) are not read from the input; the
    masked load supplies 0.0 for them and they are written to the output
    like any other column.

    When SKIP_SCALE is false, a single scale is loaded from ``scale_ptr``
    and the input is divided by it. In both modes the value is clamped to
    [fp8_min, fp8_max] before being cast to the output element type.
    """
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    mask_m = offs_m < n_rows
    # Output covers the padded width; input only covers the real width.
    mask_out = mask_m[:, None] & (offs_n[None, :] < n_cols_padded)
    mask_in = mask_m[:, None] & (offs_n[None, :] < n_cols)

    # Decompose flattened row into (token, head) for 3D stride indexing.
    # Widen to int64 before multiplying by the strides: the indices come
    # from int32 arithmetic, and token_index * stride can exceed the int32
    # range for large inputs, which would wrap the computed address.
    s = (offs_m // num_heads).to(tl.int64)
    h = (offs_m % num_heads).to(tl.int64)

    x_ptrs = x_ptr + s[:, None] * stride_xs + h[:, None] * stride_xh + offs_n[None, :] * stride_xd
    x = tl.load(x_ptrs, mask=mask_in, other=0.0).to(tl.float32)
    if SKIP_SCALE:
        x_q = x
    else:
        scale = tl.load(scale_ptr)
        x_q = x / scale
    x_q = tl.clamp(x_q, fp8_min, fp8_max).to(y_ptr.dtype.element_ty)

    y_ptrs = y_ptr + s[:, None] * stride_ys + h[:, None] * stride_yh + offs_n[None, :] * stride_yd
    tl.store(y_ptrs, x_q, mask=mask_out)
67+
68+
69+
def _get_fp8_pad_quant_config(padded_head_dim: int) -> tuple[int, int, int]:
    """Return the default ``(block_m, block_n, num_warps)`` launch config.

    ``block_n`` is the next power of two at or above ``padded_head_dim``,
    limited to the range [16, 128]. ``block_m`` and ``num_warps`` are the
    fixed values 16 and 4.
    """
    pow2_width = triton.next_power_of_2(padded_head_dim)
    if pow2_width < 16:
        tile_cols = 16
    elif pow2_width > 128:
        tile_cols = 128
    else:
        tile_cols = pow2_width
    return 16, tile_cols, 4
75+
76+
77+
def quantize_fp8_pad_head_dim_triton(
    tensor: torch.Tensor,
    scale: torch.Tensor,
    skip_scale: bool = False,
    block_m: int | None = None,
    block_n: int | None = None,
    num_warps: int | None = None,
) -> torch.Tensor:
    """Quantize a 3D/4D tensor to FP8, padding head_dim to a multiple of 16.

    Reads directly from the input using its 3D strides, so non-contiguous
    views (e.g. Q/K/V slices from an interleaved QKV buffer) are handled
    without an extra copy. A 4D input is first flattened to 3D; if its
    leading dims cannot be merged as a view, ``reshape`` makes a copy
    instead of raising.

    Args:
        tensor: Input of shape (S, H, D) or (B, S, H, D).
        scale: Scale tensor; the kernel reads its first element and divides
            the input by it unless ``skip_scale`` is set.
        skip_scale: When true, the input is only clamped and cast.
        block_m: Optional row tile size override.
        block_n: Optional column tile size override.
        num_warps: Optional warp count override.
            Any of the three overrides left as ``None`` is filled from
            ``_get_fp8_pad_quant_config``; the ones supplied are kept.

    Returns:
        A freshly allocated FP8 tensor with the input's leading dims and a
        last dim of ``round_up(D, 16)``.

    Raises:
        RuntimeError: If Triton is not available.
    """
    if not HAS_TRITON:
        raise RuntimeError("Triton is required to quantize with head_dim padding.")

    original_shape = tensor.shape
    if tensor.dim() == 4:
        # reshape returns a view whenever view() would have succeeded and
        # only copies for stride layouts that cannot be merged.
        tensor = tensor.reshape(-1, tensor.shape[-2], tensor.shape[-1])
    assert tensor.dim() == 3, f"Expected 3D input (S, H, D), got {tensor.dim()}D"
    S, H, D = tensor.shape
    padded_head_dim = round_up(D, 16)
    out_dtype = current_platform.fp8_dtype()
    output = torch.empty(
        (S, H, padded_head_dim),
        device=tensor.device,
        dtype=out_dtype,
    )
    out_shape = (*original_shape[:-1], padded_head_dim)

    n_rows = S * H
    if n_rows == 0:
        # Nothing to quantize; avoid launching a kernel with an empty grid.
        return output.view(out_shape)

    scale_1d = scale.reshape(-1)

    # Fill in only the tuning knobs the caller left unspecified, so that a
    # partial override (e.g. just block_m) is honoured.
    if block_m is None or block_n is None or num_warps is None:
        default_m, default_n, default_warps = _get_fp8_pad_quant_config(padded_head_dim)
        if block_m is None:
            block_m = default_m
        if block_n is None:
            block_n = default_n
        if num_warps is None:
            num_warps = default_warps

    grid = (
        triton.cdiv(n_rows, block_m),
        triton.cdiv(padded_head_dim, block_n),
    )

    _quantize_pad_fp8_kernel[grid](
        tensor,
        output,
        scale_1d,
        tensor.stride(0),
        tensor.stride(1),
        tensor.stride(2),
        output.stride(0),
        output.stride(1),
        output.stride(2),
        H,
        n_rows,
        D,
        padded_head_dim,
        _FP8_MIN,
        _FP8_MAX,
        SKIP_SCALE=skip_scale,
        BLOCK_M=block_m,
        BLOCK_N=block_n,
        num_warps=num_warps,
    )

    return output.view(out_shape)
142+
143+
144+
def quantize_fp8_maybe_pad_head_dim(
    tensor: torch.Tensor,
    scale: torch.Tensor,
    fp8_quant: QuantFP8,
    skip_scale: bool = False,
) -> torch.Tensor:
    """Quantize a 3D/4D tensor to FP8, padding head_dim to a multiple of 16
    only when needed.

    Accepts (S, H, D) or (B, S, H, D) input. Uses ``fp8_quant`` (a
    :class:`QuantFP8` CustomOp) when head_dim is already aligned to 16
    (no padding); otherwise falls back to a stride-aware Triton kernel
    that pads head_dim to a multiple of 16.

    Args:
        tensor: Input to quantize.
        scale: Scale passed to the Triton kernel or to ``fp8_quant``.
        fp8_quant: Quantization op used on the aligned, scaled path.
        skip_scale: When true, no scale is applied; values are clamped to
            the FP8 range and cast.

    Returns:
        The FP8 tensor. Its shape equals the input shape when head_dim is
        aligned; otherwise the last dim is padded up to a multiple of 16.
    """
    head_dim = tensor.shape[-1]
    if head_dim % 16 != 0:
        return quantize_fp8_pad_head_dim_triton(tensor, scale, skip_scale=skip_scale)

    if skip_scale:
        # Clamp before casting so out-of-range values saturate exactly as
        # they do in the padded Triton path, instead of being left to the
        # overflow behaviour of the dtype cast.
        return tensor.clamp(_FP8_MIN, _FP8_MAX).to(current_platform.fp8_dtype())

    # QuantFP8 expects 2D: collapse the leading dims into one token axis
    # and (H, D) into the second axis, then restore the shape afterwards.
    orig_shape = tensor.shape
    total_tokens = tensor.numel() // (orig_shape[-1] * orig_shape[-2])
    tensor_2d = tensor.reshape(total_tokens, -1)
    fp8_tensor, _ = fp8_quant(tensor_2d, scale=scale)
    return fp8_tensor.reshape(orig_shape)

0 commit comments

Comments
 (0)