-
-
Notifications
You must be signed in to change notification settings - Fork 200
[mm] vision attention backend for XPU #1584
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -260,21 +260,23 @@ def reshape_and_cache_flash( | |
|
|
||
| @staticmethod | ||
| def flash_attn_varlen_func( | ||
| out: torch.Tensor, | ||
| q: torch.Tensor, | ||
| k: torch.Tensor, | ||
| v: torch.Tensor, | ||
| cu_seqlens_q: torch.Tensor, | ||
| seqused_k: torch.Tensor, # we don't support this in ipex kernel | ||
| max_seqlen_q: int, | ||
| max_seqlen_k: int, | ||
| softmax_scale: float, | ||
| causal: bool, | ||
| block_table: torch.Tensor, | ||
| alibi_slopes: torch.Tensor | None, | ||
| softmax_scale: float | None = None, | ||
| causal: bool = False, | ||
| out: torch.Tensor | None = None, | ||
| block_table: torch.Tensor | None = None, | ||
| alibi_slopes: torch.Tensor | None = None, | ||
| window_size: list[int] | None = None, | ||
| softcap: float | None = 0.0, | ||
| seqused_k: torch.Tensor | None = None, | ||
| cu_seqlens_k: torch.Tensor | None = None, | ||
| # passed in qwen vl | ||
| dropout_p: float = 0.0, | ||
| # The following parameters are not used in ipex kernel currently, | ||
| # we keep API compatible to CUDA's. | ||
| scheduler_metadata=None, | ||
|
|
@@ -285,31 +287,60 @@ def flash_attn_varlen_func( | |
| num_splits=0, | ||
| s_aux: torch.Tensor | None = None, | ||
| ): | ||
| if out is None: | ||
| out = torch.empty(q.shape, dtype=q.dtype, device=q.device) | ||
|
Comment on lines
+290
to
+291
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Initializing |
||
| real_window_size: tuple[int, int] | ||
| if window_size is None: | ||
| real_window_size = (-1, -1) | ||
| else: | ||
| assert len(window_size) == 2 | ||
| real_window_size = (window_size[0], window_size[1]) | ||
| return ipex.llm.modules.PagedAttention.flash_attn_varlen_func( | ||
| out, | ||
| q.contiguous(), | ||
| k, | ||
| v, | ||
| cu_seqlens_q, | ||
| seqused_k, | ||
| max_seqlen_q, | ||
| max_seqlen_k, | ||
| softmax_scale, | ||
| causal, | ||
| block_table, | ||
| alibi_slopes, | ||
| softcap=softcap, | ||
| window_size_left=real_window_size[0], | ||
| window_size_right=real_window_size[1], | ||
| k_scale=1.0, | ||
| v_scale=1.0, | ||
| ) | ||
| if block_table is None: | ||
| assert cu_seqlens_k is not None, "cu_seqlens_k can't be None when calling varlen_attention." | ||
| if softmax_scale is None: | ||
| softmax_scale = q.shape[-1] ** (-0.5) | ||
| ipex_ops.varlen_attention( | ||
| q.contiguous(), | ||
| k.contiguous(), | ||
| v.contiguous(), | ||
| out, | ||
| cu_seqlens_q, | ||
| cu_seqlens_k, | ||
| None, | ||
| max_seqlen_q, | ||
| max_seqlen_k, | ||
| 0.0, | ||
| softmax_scale, | ||
| False, | ||
| causal, | ||
| False, | ||
| None, | ||
| real_window_size[0], | ||
| real_window_size[1], | ||
| -1, | ||
| ) | ||
|
Comment on lines
+303
to
+321
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The |
||
| return out | ||
| else: | ||
| return ipex.llm.modules.PagedAttention.flash_attn_varlen_func( | ||
| out, | ||
| q.contiguous(), | ||
| k, | ||
| v, | ||
| cu_seqlens_q, | ||
| seqused_k, | ||
| max_seqlen_q, | ||
| max_seqlen_k, | ||
| softmax_scale, | ||
| causal, | ||
| block_table, | ||
| alibi_slopes, | ||
| sink=s_aux, | ||
| softcap=softcap, | ||
| window_size_left=real_window_size[0], | ||
| window_size_right=real_window_size[1], | ||
| k_scale=1.0, | ||
| v_scale=1.0, | ||
| ) | ||
|
|
||
| @staticmethod | ||
| def get_scheduler_metadata( | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -100,6 +100,9 @@ def maybe_get_vit_flash_attn_backend( | |
| if attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(torch.get_default_dtype()): | ||
| attn_backend = _Backend.FLASH_ATTN | ||
| use_upstream_fa = True | ||
| elif current_platform.is_xpu(): | ||
| assert attn_backend == _Backend.FLASH_ATTN, "XPU platform only supports FLASH_ATTN as vision attention backend." | ||
| use_upstream_fa = False | ||
|
Comment on lines
+103
to
+105
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The |
||
| else: | ||
| return _Backend.TORCH_SDPA, None | ||
|
|
||
|
|
@@ -110,7 +113,7 @@ def maybe_get_vit_flash_attn_backend( | |
| if use_upstream_fa: | ||
| from flash_attn import flash_attn_varlen_func | ||
| else: | ||
| from aphrodite.aphrodite_flash_attn import flash_attn_varlen_func | ||
| from aphrodite.attention.utils.fa_utils import flash_attn_varlen_func | ||
| else: | ||
| flash_attn_varlen_func = None | ||
|
|
||
|
|
@@ -473,22 +476,18 @@ def __init__( | |
| # If aphrodite native fa is selected, we use it directly. | ||
| use_upstream_fa = False | ||
|
|
||
| if current_platform.is_xpu(): | ||
| # currently, only torch_sdpa is supported on xpu | ||
| self.attn_backend = _Backend.TORCH_SDPA | ||
| else: | ||
| self.attn_backend = ( | ||
| backend | ||
| if backend | ||
| in { | ||
| _Backend.TORCH_SDPA, | ||
| _Backend.XFORMERS, | ||
| _Backend.PALLAS, | ||
| _Backend.ROCM_AITER_FA, | ||
| _Backend.FLASH_ATTN, | ||
| } | ||
| else _Backend.TORCH_SDPA | ||
| ) | ||
| self.attn_backend = ( | ||
| backend | ||
| if backend | ||
| in { | ||
| _Backend.TORCH_SDPA, | ||
| _Backend.XFORMERS, | ||
| _Backend.PALLAS, | ||
| _Backend.ROCM_AITER_FA, | ||
| _Backend.FLASH_ATTN, | ||
| } | ||
| else _Backend.TORCH_SDPA | ||
| ) | ||
|
Comment on lines
+479
to
+490
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The removal of the |
||
|
|
||
| self.attn_backend, self._flash_attn_varlen_func = maybe_get_vit_flash_attn_backend( | ||
| self.attn_backend, | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -327,6 +327,8 @@ def __init__( | |
|
|
||
| if current_platform.is_rocm() and self.attn_backend == _Backend.FLASH_ATTN: | ||
| self.use_upstream_fa = True | ||
| if current_platform.is_xpu(): | ||
| self.use_upstream_fa = False | ||
|
Comment on lines
+330
to
+331
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Setting |
||
| self.is_flash_attn_backend = self.attn_backend in { | ||
| _Backend.FLASH_ATTN, | ||
| _Backend.ROCM_AITER_FA, | ||
|
|
@@ -793,7 +795,7 @@ def compute_attn_mask_seqlen( | |
| ) -> tuple[torch.Tensor, torch.Tensor]: | ||
| max_seqlen = torch.zeros([], device=cu_seqlens.device) | ||
| seqlens = torch.zeros(1, device=cu_seqlens.device) | ||
| if self.attn_backend == _Backend.FLASH_ATTN or self.attn_backend == _Backend.ROCM_AITER_FA: | ||
| if self.attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}: | ||
| max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() | ||
| elif self.attn_backend == _Backend.XFORMERS: | ||
| seqlens = cu_seqlens[1:] - cu_seqlens[:-1] | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -110,6 +110,12 @@ def get_device_total_memory(cls, device_id: int = 0) -> int: | |
| device_props = torch.xpu.get_device_properties(device_id) | ||
| return device_props.total_memory | ||
|
|
||
| @classmethod | ||
| def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> _Backend: | ||
| from aphrodite.attention.backends.registry import _Backend | ||
|
|
||
| return _Backend.FLASH_ATTN | ||
|
Comment on lines
+114
to
+117
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The |
||
|
|
||
| @classmethod | ||
| def inference_mode(cls): | ||
| return torch.no_grad() | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The reordering of parameters in
`flash_attn_varlen_func` makes the function signature less intuitive. While `out` is now optional, it's generally good practice to keep required parameters before optional ones. Moving `out` to the end, or at least after all other required parameters, would improve readability and maintain consistency with common Python function signature conventions.