Skip to content

Commit aeb0e1f

Browse files
committed
issue/889 - optimize flash attention performance from kernel
1 parent 4201ea7 commit aeb0e1f

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

src/infiniop/ops/flash_attention/ninetoothed/flash_attention.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ def application_without_kv_cache(
183183
lse = ntl.full((query_i.shape[-2],), 1, dtype=ntl.float32)
184184
max = ntl.full((query_i.shape[-2],), float("-inf"), dtype=ntl.float32)
185185

186-
for j in range(key.shape[0]):
186+
for j in range(-(-actual_kv_len // key.dtype.shape[0])):
187187

188188
qk = ntl.dot(query_i, ntl.trans(key[j]))
189189

0 commit comments

Comments
 (0)