Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 56 additions & 25 deletions aphrodite/_ipex_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,21 +260,23 @@ def reshape_and_cache_flash(

@staticmethod
def flash_attn_varlen_func(
out: torch.Tensor,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
cu_seqlens_q: torch.Tensor,
seqused_k: torch.Tensor, # we don't support this in ipex kernel
max_seqlen_q: int,
max_seqlen_k: int,
softmax_scale: float,
causal: bool,
block_table: torch.Tensor,
alibi_slopes: torch.Tensor | None,
softmax_scale: float | None = None,
causal: bool = False,
out: torch.Tensor | None = None,
block_table: torch.Tensor | None = None,
alibi_slopes: torch.Tensor | None = None,
window_size: list[int] | None = None,
softcap: float | None = 0.0,
seqused_k: torch.Tensor | None = None,
cu_seqlens_k: torch.Tensor | None = None,
# passed in qwen vl
dropout_p: float = 0.0,
Comment on lines 263 to +279

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The reordering of parameters in flash_attn_varlen_func makes the function signature less intuitive. While out is now optional, it's generally good practice to keep required parameters before optional ones. Moving out to the end, or at least after all other required parameters, would improve readability and maintain consistency with common Python function signature conventions.

Suggested change
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
cu_seqlens_q: torch.Tensor,
seqused_k: torch.Tensor, # we don't support this in ipex kernel
max_seqlen_q: int,
max_seqlen_k: int,
softmax_scale: float,
causal: bool,
block_table: torch.Tensor,
alibi_slopes: torch.Tensor | None,
softmax_scale: float | None = None,
causal: bool = False,
out: torch.Tensor | None = None,
block_table: torch.Tensor | None = None,
alibi_slopes: torch.Tensor | None = None,
window_size: list[int] | None = None,
softcap: float | None = 0.0,
seqused_k: torch.Tensor | None = None,
cu_seqlens_k: torch.Tensor | None = None,
# passed in qwen vl
dropout_p: float = 0.0,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
cu_seqlens_q: torch.Tensor,
max_seqlen_q: int,
max_seqlen_k: int,
softmax_scale: float | None = None,
causal: bool = False,
block_table: torch.Tensor | None = None,
alibi_slopes: torch.Tensor | None = None,
window_size: list[int] | None = None,
softcap: float | None = 0.0,
seqused_k: torch.Tensor | None = None,
cu_seqlens_k: torch.Tensor | None = None,
dropout_p: float = 0.0,
out: torch.Tensor | None = None,

# The following parameters are not used in ipex kernel currently,
# we keep API compatible to CUDA's.
scheduler_metadata=None,
Expand All @@ -285,31 +287,60 @@ def flash_attn_varlen_func(
num_splits=0,
s_aux: torch.Tensor | None = None,
):
if out is None:
out = torch.empty(q.shape, dtype=q.dtype, device=q.device)
Comment on lines +290 to +291

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Initializing out with torch.empty inside the function when out is None is a good approach. However, consider if q.shape, q.dtype, and q.device are always the correct attributes to use for out's initialization, especially if out might have a different expected shape or device in some edge cases not covered by the current logic. If out is intended to be the same shape as q, this is fine.

real_window_size: tuple[int, int]
if window_size is None:
real_window_size = (-1, -1)
else:
assert len(window_size) == 2
real_window_size = (window_size[0], window_size[1])
return ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
out,
q.contiguous(),
k,
v,
cu_seqlens_q,
seqused_k,
max_seqlen_q,
max_seqlen_k,
softmax_scale,
causal,
block_table,
alibi_slopes,
softcap=softcap,
window_size_left=real_window_size[0],
window_size_right=real_window_size[1],
k_scale=1.0,
v_scale=1.0,
)
if block_table is None:
assert cu_seqlens_k is not None, "cu_seqlens_k can't be None when calling varlen_attention."
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
ipex_ops.varlen_attention(
q.contiguous(),
k.contiguous(),
v.contiguous(),
out,
cu_seqlens_q,
cu_seqlens_k,
None,
max_seqlen_q,
max_seqlen_k,
0.0,
softmax_scale,
False,
causal,
False,
None,
real_window_size[0],
real_window_size[1],
-1,
)
Comment on lines +303 to +321

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The ipex_ops.varlen_attention call has several hardcoded values (e.g., 0.0 for dropout, False for is_causal, False for return_softmax, None for attn_mask, -1 for num_splits). While these might be the default or desired values for this specific use case, it's generally better to pass them as explicit arguments or derive them from existing parameters if they are configurable elsewhere. This improves clarity and makes the function more flexible for future changes.

return out
else:
return ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
out,
q.contiguous(),
k,
v,
cu_seqlens_q,
seqused_k,
max_seqlen_q,
max_seqlen_k,
softmax_scale,
causal,
block_table,
alibi_slopes,
sink=s_aux,
softcap=softcap,
window_size_left=real_window_size[0],
window_size_right=real_window_size[1],
k_scale=1.0,
v_scale=1.0,
)

@staticmethod
def get_scheduler_metadata(
Expand Down
33 changes: 16 additions & 17 deletions aphrodite/attention/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,9 @@ def maybe_get_vit_flash_attn_backend(
if attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(torch.get_default_dtype()):
attn_backend = _Backend.FLASH_ATTN
use_upstream_fa = True
elif current_platform.is_xpu():
assert attn_backend == _Backend.FLASH_ATTN, "XPU platform only supports FLASH_ATTN as vision attention backend."
use_upstream_fa = False
Comment on lines +103 to +105

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The assert statement here will raise an AssertionError if attn_backend is not _Backend.FLASH_ATTN on XPU. While this ensures the correct backend is used, it might be more user-friendly to raise a custom exception (e.g., ValueError) with a more descriptive message, or to handle this case gracefully by falling back to a default backend if possible, rather than crashing the program. This depends on the expected behavior and error handling strategy for the application.

else:
return _Backend.TORCH_SDPA, None

Expand All @@ -110,7 +113,7 @@ def maybe_get_vit_flash_attn_backend(
if use_upstream_fa:
from flash_attn import flash_attn_varlen_func
else:
from aphrodite.aphrodite_flash_attn import flash_attn_varlen_func
from aphrodite.attention.utils.fa_utils import flash_attn_varlen_func
else:
flash_attn_varlen_func = None

Expand Down Expand Up @@ -473,22 +476,18 @@ def __init__(
# If aphrodite native fa is selected, we use it directly.
use_upstream_fa = False

if current_platform.is_xpu():
# currently, only torch_sdpa is supported on xpu
self.attn_backend = _Backend.TORCH_SDPA
else:
self.attn_backend = (
backend
if backend
in {
_Backend.TORCH_SDPA,
_Backend.XFORMERS,
_Backend.PALLAS,
_Backend.ROCM_AITER_FA,
_Backend.FLASH_ATTN,
}
else _Backend.TORCH_SDPA
)
self.attn_backend = (
backend
if backend
in {
_Backend.TORCH_SDPA,
_Backend.XFORMERS,
_Backend.PALLAS,
_Backend.ROCM_AITER_FA,
_Backend.FLASH_ATTN,
}
else _Backend.TORCH_SDPA
)
Comment on lines +479 to +490

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The removal of the if current_platform.is_xpu(): block and the direct assignment of self.attn_backend means that the XPU platform no longer explicitly defaults to _Backend.TORCH_SDPA if backend is not FLASH_ATTN. This change aligns with the maybe_get_vit_flash_attn_backend function, which now handles XPU-specific backend selection. Ensure that this change doesn't inadvertently allow unsupported backends on XPU or introduce unexpected behavior if backend is not FLASH_ATTN when running on XPU.


self.attn_backend, self._flash_attn_varlen_func = maybe_get_vit_flash_attn_backend(
self.attn_backend,
Expand Down
2 changes: 1 addition & 1 deletion aphrodite/attention/ops/vit_attn_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def flash_attn_maxseqlen_wrapper(
if use_upstream_fa:
from flash_attn import flash_attn_varlen_func
else:
from aphrodite.aphrodite_flash_attn import flash_attn_varlen_func
from aphrodite.attention.utils.fa_utils import flash_attn_varlen_func
q, k, v = (einops.rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
output = flash_attn_varlen_func(
q,
Expand Down
4 changes: 3 additions & 1 deletion aphrodite/modeling/models/qwen2_5_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,8 @@ def __init__(

if current_platform.is_rocm() and self.attn_backend == _Backend.FLASH_ATTN:
self.use_upstream_fa = True
if current_platform.is_xpu():
self.use_upstream_fa = False
Comment on lines +330 to +331

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

Setting self.use_upstream_fa = False specifically for XPU is a critical change. This ensures that the custom aphrodite Flash Attention implementation is used instead of the upstream flash_attn library on XPU. This is important for compatibility and performance on XPU, but it's crucial to ensure that the custom implementation is fully tested and optimized for XPU to avoid regressions.

self.is_flash_attn_backend = self.attn_backend in {
_Backend.FLASH_ATTN,
_Backend.ROCM_AITER_FA,
Expand Down Expand Up @@ -793,7 +795,7 @@ def compute_attn_mask_seqlen(
) -> tuple[torch.Tensor, torch.Tensor]:
max_seqlen = torch.zeros([], device=cu_seqlens.device)
seqlens = torch.zeros(1, device=cu_seqlens.device)
if self.attn_backend == _Backend.FLASH_ATTN or self.attn_backend == _Backend.ROCM_AITER_FA:
if self.attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}:
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
elif self.attn_backend == _Backend.XFORMERS:
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
Expand Down
2 changes: 1 addition & 1 deletion aphrodite/modeling/models/qwen2_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -710,7 +710,7 @@ def rot_pos_emb(self, grid_thw: list[list[int]]) -> torch.Tensor:

def compute_attn_mask_seqlen(self, cu_seqlens: torch.Tensor) -> tuple[int | None, list[int] | None]:
max_seqlen, seqlens = None, None
if self.attn_backend == _Backend.FLASH_ATTN or self.attn_backend == _Backend.ROCM_AITER_FA:
if self.attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}:
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
elif self.attn_backend == _Backend.XFORMERS:
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
Expand Down
6 changes: 6 additions & 0 deletions aphrodite/platforms/xpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,12 @@ def get_device_total_memory(cls, device_id: int = 0) -> int:
device_props = torch.xpu.get_device_properties(device_id)
return device_props.total_memory

@classmethod
def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> _Backend:
from aphrodite.attention.backends.registry import _Backend

return _Backend.FLASH_ATTN
Comment on lines +114 to +117

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The get_vit_attn_backend method for XPU explicitly returns _Backend.FLASH_ATTN. This hardcodes the vision attention backend for XPU to Flash Attention. While this might be the current strategy, consider if there's a need for flexibility in the future to support other backends or to dynamically determine the best backend based on hardware capabilities or user preferences. If not, this explicit setting is clear.


@classmethod
def inference_mode(cls):
return torch.no_grad()
Expand Down