Commit c18b75a

zhangxin authored and committed
issue/1148: PagedAttentionPrefill: add KV cache contiguity guard
1 parent 3f0a98c commit c18b75a

1 file changed

python/infinicore/ops/paged_attention_prefill.py

Lines changed: 8 additions & 0 deletions
@@ -2,6 +2,12 @@
 from infinicore.tensor import Tensor
 
 
+def _ensure_head_dim_contiguous(tensor: Tensor) -> Tensor:
+    if tensor.ndim > 0 and tensor.stride(tensor.ndim - 1) != 1:
+        return tensor.contiguous()
+    return tensor
+
+
 def paged_attention_prefill(
     q: Tensor,
     k_cache: Tensor,
@@ -14,6 +20,8 @@ def paged_attention_prefill(
     *,
     out: Tensor | None = None,
 ):
+    k_cache = _ensure_head_dim_contiguous(k_cache)
+    v_cache = _ensure_head_dim_contiguous(v_cache)
     alibi_ptr = alibi_slopes._underlying if alibi_slopes is not None else None
 
     if out is None:
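
For reference, a minimal sketch of what the new guard does, using PyTorch tensors as a stand-in for infinicore.tensor.Tensor (the torch import and the example shapes are assumptions for illustration, not part of this commit; the commit only relies on the Tensor exposing ndim, stride(dim), and contiguous()):

import torch  # stand-in only: the real code operates on infinicore.tensor.Tensor

def _ensure_head_dim_contiguous(tensor: torch.Tensor) -> torch.Tensor:
    # A last-dimension stride other than 1 means the head dimension is not
    # laid out contiguously in memory (e.g. the tensor is a transposed view),
    # so copy it into a contiguous buffer; otherwise pass it through as-is.
    if tensor.ndim > 0 and tensor.stride(tensor.ndim - 1) != 1:
        return tensor.contiguous()
    return tensor

k_cache = torch.zeros(4, 8, 64).transpose(1, 2)  # transposed view: last-dim stride is 64
assert k_cache.stride(-1) != 1
guarded = _ensure_head_dim_contiguous(k_cache)
assert guarded.stride(-1) == 1   # the guard produced a contiguous copy
assert guarded is not k_cache

plain = torch.zeros(4, 8, 64)    # already contiguous along the head dimension
assert _ensure_head_dim_contiguous(plain) is plain  # no copy, passed through

The upshot of guarding at this boundary is that the copy cost is paid only when a strided view actually reaches paged_attention_prefill; caches already contiguous along the head dimension pass through untouched.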
