diff --git a/aphrodite/_ipex_ops.py b/aphrodite/_ipex_ops.py index 7df00c6a35..3762b01c84 100644 --- a/aphrodite/_ipex_ops.py +++ b/aphrodite/_ipex_ops.py @@ -260,21 +260,23 @@ def reshape_and_cache_flash( @staticmethod def flash_attn_varlen_func( - out: torch.Tensor, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, cu_seqlens_q: torch.Tensor, - seqused_k: torch.Tensor, # we don't support this in ipex kernel max_seqlen_q: int, max_seqlen_k: int, - softmax_scale: float, - causal: bool, - block_table: torch.Tensor, - alibi_slopes: torch.Tensor | None, + softmax_scale: float | None = None, + causal: bool = False, + out: torch.Tensor | None = None, + block_table: torch.Tensor | None = None, + alibi_slopes: torch.Tensor | None = None, window_size: list[int] | None = None, softcap: float | None = 0.0, + seqused_k: torch.Tensor | None = None, cu_seqlens_k: torch.Tensor | None = None, + # passed in qwen vl + dropout_p: float = 0.0, # The following parameters are not used in ipex kernel currently, # we keep API compatible to CUDA's. scheduler_metadata=None, @@ -285,31 +287,60 @@ def flash_attn_varlen_func( num_splits=0, s_aux: torch.Tensor | None = None, ): + if out is None: + out = torch.empty(q.shape, dtype=q.dtype, device=q.device) real_window_size: tuple[int, int] if window_size is None: real_window_size = (-1, -1) else: assert len(window_size) == 2 real_window_size = (window_size[0], window_size[1]) - return ipex.llm.modules.PagedAttention.flash_attn_varlen_func( - out, - q.contiguous(), - k, - v, - cu_seqlens_q, - seqused_k, - max_seqlen_q, - max_seqlen_k, - softmax_scale, - causal, - block_table, - alibi_slopes, - softcap=softcap, - window_size_left=real_window_size[0], - window_size_right=real_window_size[1], - k_scale=1.0, - v_scale=1.0, - ) + if block_table is None: + assert cu_seqlens_k is not None, "cu_seqlens_k can't be None when calling varlen_attention." + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + ipex_ops.varlen_attention( + q.contiguous(), + k.contiguous(), + v.contiguous(), + out, + cu_seqlens_q, + cu_seqlens_k, + None, + max_seqlen_q, + max_seqlen_k, + 0.0, + softmax_scale, + False, + causal, + False, + None, + real_window_size[0], + real_window_size[1], + -1, + ) + return out + else: + return ipex.llm.modules.PagedAttention.flash_attn_varlen_func( + out, + q.contiguous(), + k, + v, + cu_seqlens_q, + seqused_k, + max_seqlen_q, + max_seqlen_k, + softmax_scale, + causal, + block_table, + alibi_slopes, + sink=s_aux, + softcap=softcap, + window_size_left=real_window_size[0], + window_size_right=real_window_size[1], + k_scale=1.0, + v_scale=1.0, + ) @staticmethod def get_scheduler_metadata( diff --git a/aphrodite/attention/layer.py b/aphrodite/attention/layer.py index 6e1294db45..e3a20fa4e1 100644 --- a/aphrodite/attention/layer.py +++ b/aphrodite/attention/layer.py @@ -100,6 +100,9 @@ def maybe_get_vit_flash_attn_backend( if attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(torch.get_default_dtype()): attn_backend = _Backend.FLASH_ATTN use_upstream_fa = True + elif current_platform.is_xpu(): + assert attn_backend == _Backend.FLASH_ATTN, "XPU platform only supports FLASH_ATTN as vision attention backend." + use_upstream_fa = False else: return _Backend.TORCH_SDPA, None @@ -110,7 +113,7 @@ def maybe_get_vit_flash_attn_backend( if use_upstream_fa: from flash_attn import flash_attn_varlen_func else: - from aphrodite.aphrodite_flash_attn import flash_attn_varlen_func + from aphrodite.attention.utils.fa_utils import flash_attn_varlen_func else: flash_attn_varlen_func = None @@ -473,22 +476,18 @@ def __init__( # If aphrodite native fa is selected, we use it directly. use_upstream_fa = False - if current_platform.is_xpu(): - # currently, only torch_sdpa is supported on xpu - self.attn_backend = _Backend.TORCH_SDPA - else: - self.attn_backend = ( - backend - if backend - in { - _Backend.TORCH_SDPA, - _Backend.XFORMERS, - _Backend.PALLAS, - _Backend.ROCM_AITER_FA, - _Backend.FLASH_ATTN, - } - else _Backend.TORCH_SDPA - ) + self.attn_backend = ( + backend + if backend + in { + _Backend.TORCH_SDPA, + _Backend.XFORMERS, + _Backend.PALLAS, + _Backend.ROCM_AITER_FA, + _Backend.FLASH_ATTN, + } + else _Backend.TORCH_SDPA + ) self.attn_backend, self._flash_attn_varlen_func = maybe_get_vit_flash_attn_backend( self.attn_backend, diff --git a/aphrodite/attention/ops/vit_attn_wrappers.py b/aphrodite/attention/ops/vit_attn_wrappers.py index 1a37ccded5..d9218edde4 100644 --- a/aphrodite/attention/ops/vit_attn_wrappers.py +++ b/aphrodite/attention/ops/vit_attn_wrappers.py @@ -62,7 +62,7 @@ def flash_attn_maxseqlen_wrapper( if use_upstream_fa: from flash_attn import flash_attn_varlen_func else: - from aphrodite.aphrodite_flash_attn import flash_attn_varlen_func + from aphrodite.attention.utils.fa_utils import flash_attn_varlen_func q, k, v = (einops.rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]) output = flash_attn_varlen_func( q, diff --git a/aphrodite/modeling/models/qwen2_5_vl.py b/aphrodite/modeling/models/qwen2_5_vl.py index c66039ce3e..352cb92c3a 100644 --- a/aphrodite/modeling/models/qwen2_5_vl.py +++ b/aphrodite/modeling/models/qwen2_5_vl.py @@ -327,6 +327,8 @@ def __init__( if current_platform.is_rocm() and self.attn_backend == _Backend.FLASH_ATTN: self.use_upstream_fa = True + if current_platform.is_xpu(): + self.use_upstream_fa = False self.is_flash_attn_backend = self.attn_backend in { _Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA, @@ -793,7 +795,7 @@ def compute_attn_mask_seqlen( ) -> tuple[torch.Tensor, torch.Tensor]: max_seqlen = torch.zeros([], device=cu_seqlens.device) seqlens = torch.zeros(1, device=cu_seqlens.device) - if self.attn_backend == _Backend.FLASH_ATTN or self.attn_backend == _Backend.ROCM_AITER_FA: + if self.attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}: max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() elif self.attn_backend == _Backend.XFORMERS: seqlens = cu_seqlens[1:] - cu_seqlens[:-1] diff --git a/aphrodite/modeling/models/qwen2_vl.py b/aphrodite/modeling/models/qwen2_vl.py index a8e4df63b4..73207be26f 100644 --- a/aphrodite/modeling/models/qwen2_vl.py +++ b/aphrodite/modeling/models/qwen2_vl.py @@ -710,7 +710,7 @@ def rot_pos_emb(self, grid_thw: list[list[int]]) -> torch.Tensor: def compute_attn_mask_seqlen(self, cu_seqlens: torch.Tensor) -> tuple[int | None, list[int] | None]: max_seqlen, seqlens = None, None - if self.attn_backend == _Backend.FLASH_ATTN or self.attn_backend == _Backend.ROCM_AITER_FA: + if self.attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}: max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() elif self.attn_backend == _Backend.XFORMERS: seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() diff --git a/aphrodite/platforms/xpu.py b/aphrodite/platforms/xpu.py index 84cc67b2fa..d5d6477a27 100644 --- a/aphrodite/platforms/xpu.py +++ b/aphrodite/platforms/xpu.py @@ -110,6 +110,12 @@ def get_device_total_memory(cls, device_id: int = 0) -> int: device_props = torch.xpu.get_device_properties(device_id) return device_props.total_memory + @classmethod + def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> _Backend: + from aphrodite.attention.backends.registry import _Backend + + return _Backend.FLASH_ATTN + @classmethod def inference_mode(cls): return torch.no_grad()