We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 3f0a98c · commit c18b75a (copy full SHA for c18b75a)
1 file changed
python/infinicore/ops/paged_attention_prefill.py
@@ -2,6 +2,12 @@
2
from infinicore.tensor import Tensor
3
4
5
+def _ensure_head_dim_contiguous(tensor: Tensor) -> Tensor:
6
+ if tensor.ndim > 0 and tensor.stride(tensor.ndim - 1) != 1:
7
+ return tensor.contiguous()
8
+ return tensor
9
+
10
11
def paged_attention_prefill(
12
q: Tensor,
13
k_cache: Tensor,
@@ -14,6 +20,8 @@ def paged_attention_prefill(
14
20
*,
15
21
out: Tensor | None = None,
16
22
):
23
+ k_cache = _ensure_head_dim_contiguous(k_cache)
24
+ v_cache = _ensure_head_dim_contiguous(v_cache)
17
25
alibi_ptr = alibi_slopes._underlying if alibi_slopes is not None else None
18
26
19
27
if out is None:
0 commit comments