Skip to content

Commit ab78777

Browse files
committed
Fix
1 parent e801f7c commit ab78777

12 files changed

Lines changed: 393 additions & 77 deletions

flash-attn2/build.toml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,15 @@ src = [
206206
"flash_attn_xpu/src/flash_bwd_hdim192_fix.cpp",
207207
"flash_attn_xpu/src/flash_bwd_hdim256_fix.cpp",
208208
"flash_attn_xpu/src/flash_bwd_hdim512_fix.cpp",
209+
"flash_attn_xpu/src/fmha_fwd_kvcache_impl.hpp",
210+
"flash_attn_xpu/src/flash_kvcache_hdim32_fix.cpp",
211+
"flash_attn_xpu/src/flash_kvcache_hdim64_fix.cpp",
212+
"flash_attn_xpu/src/flash_kvcache_hdim96_fix.cpp",
213+
"flash_attn_xpu/src/flash_kvcache_hdim128_fix.cpp",
214+
"flash_attn_xpu/src/flash_kvcache_hdim160_fix.cpp",
215+
"flash_attn_xpu/src/flash_kvcache_hdim192_fix.cpp",
216+
"flash_attn_xpu/src/flash_kvcache_hdim256_fix.cpp",
217+
"flash_attn_xpu/src/kernel/fmha_fwd_kvcache_kernel.hpp",
209218
"flash_attn_xpu/src/fmha_utils.hpp",
210219
"flash_attn_xpu/src/collective/fmha_fusion.hpp",
211220
"flash_attn_xpu/src/collective/copy_block_slm.hpp",

flash-attn2/flash_attn_xpu/flash_api.cpp

Lines changed: 276 additions & 57 deletions
Large diffs are not rendered by default.

flash-attn2/flash_attn_xpu/src/flash_kvcache_hdim128_fix.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,13 @@ template void kvcache_policy_dispatch<
1111
sycl::queue& queue,
1212
CutlassType cuType,
1313
const fmha_fwd_kvcache_args_t& args);
14+
15+
// KVCache paged mode: IsVarLen=0, IsPaged=1
16+
template void kvcache_policy_dispatch<
17+
prefill_policy_head128,
18+
1, // PipelineStages
19+
0, // IsVarLen=0 (fixed length)
20+
1>( // IsPaged=1 (paged)
21+
sycl::queue& queue,
22+
CutlassType cuType,
23+
const fmha_fwd_kvcache_args_t& args);

flash-attn2/flash_attn_xpu/src/flash_kvcache_hdim160_fix.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,13 @@ template void kvcache_policy_dispatch<
1111
sycl::queue& queue,
1212
CutlassType cuType,
1313
const fmha_fwd_kvcache_args_t& args);
14+
15+
// KVCache paged mode: IsVarLen=0, IsPaged=1
16+
template void kvcache_policy_dispatch<
17+
prefill_policy_head160,
18+
1, // PipelineStages
19+
0, // IsVarLen=0 (fixed length)
20+
1>( // IsPaged=1 (paged)
21+
sycl::queue& queue,
22+
CutlassType cuType,
23+
const fmha_fwd_kvcache_args_t& args);

flash-attn2/flash_attn_xpu/src/flash_kvcache_hdim192_fix.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,13 @@ template void kvcache_policy_dispatch<
1111
sycl::queue& queue,
1212
CutlassType cuType,
1313
const fmha_fwd_kvcache_args_t& args);
14+
15+
// KVCache paged mode: IsVarLen=0, IsPaged=1
16+
template void kvcache_policy_dispatch<
17+
prefill_policy_head192,
18+
1, // PipelineStages
19+
0, // IsVarLen=0 (fixed length)
20+
1>( // IsPaged=1 (paged)
21+
sycl::queue& queue,
22+
CutlassType cuType,
23+
const fmha_fwd_kvcache_args_t& args);

flash-attn2/flash_attn_xpu/src/flash_kvcache_hdim256_fix.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,13 @@ template void kvcache_policy_dispatch<
1111
sycl::queue& queue,
1212
CutlassType cuType,
1313
const fmha_fwd_kvcache_args_t& args);
14+
15+
// KVCache paged mode: IsVarLen=0, IsPaged=1
16+
template void kvcache_policy_dispatch<
17+
prefill_policy_head256,
18+
1, // PipelineStages
19+
0, // IsVarLen=0 (fixed length)
20+
1>( // IsPaged=1 (paged)
21+
sycl::queue& queue,
22+
CutlassType cuType,
23+
const fmha_fwd_kvcache_args_t& args);

flash-attn2/flash_attn_xpu/src/flash_kvcache_hdim32_fix.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,13 @@ template void kvcache_policy_dispatch<
1111
sycl::queue& queue,
1212
CutlassType cuType,
1313
const fmha_fwd_kvcache_args_t& args);
14+
15+
// KVCache paged mode: IsVarLen=0, IsPaged=1
16+
template void kvcache_policy_dispatch<
17+
prefill_policy_head32,
18+
1, // PipelineStages
19+
0, // IsVarLen=0 (fixed length)
20+
1>( // IsPaged=1 (paged)
21+
sycl::queue& queue,
22+
CutlassType cuType,
23+
const fmha_fwd_kvcache_args_t& args);

flash-attn2/flash_attn_xpu/src/flash_kvcache_hdim64_fix.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,13 @@ template void kvcache_policy_dispatch<
1111
sycl::queue& queue,
1212
CutlassType cuType,
1313
const fmha_fwd_kvcache_args_t& args);
14+
15+
// KVCache paged mode: IsVarLen=0, IsPaged=1
16+
template void kvcache_policy_dispatch<
17+
prefill_policy_head64,
18+
1, // PipelineStages
19+
0, // IsVarLen=0 (fixed length)
20+
1>( // IsPaged=1 (paged)
21+
sycl::queue& queue,
22+
CutlassType cuType,
23+
const fmha_fwd_kvcache_args_t& args);

flash-attn2/flash_attn_xpu/src/flash_kvcache_hdim96_fix.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,13 @@ template void kvcache_policy_dispatch<
1111
sycl::queue& queue,
1212
CutlassType cuType,
1313
const fmha_fwd_kvcache_args_t& args);
14+
15+
// KVCache paged mode: IsVarLen=0, IsPaged=1
16+
template void kvcache_policy_dispatch<
17+
prefill_policy_head96,
18+
1, // PipelineStages
19+
0, // IsVarLen=0 (fixed length)
20+
1>( // IsPaged=1 (paged)
21+
sycl::queue& queue,
22+
CutlassType cuType,
23+
const fmha_fwd_kvcache_args_t& args);

flash-attn2/flash_attn_xpu/src/fmha_fwd_kvcache_impl.hpp

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ struct KVCacheKernelLauncher {
140140
{args.sm_scale,
141141
static_cast<int*>(args.block_table),
142142
args.page_block_size,
143-
0, // max_blocks_per_seq - set below if paged
143+
args.max_blocks_per_seq, // max_pages_per_seq for paged KV
144144
args.seqlen_k, // total_seqlen_kv
145145
args.window_size_left,
146146
args.window_size_right,
@@ -372,25 +372,35 @@ inline void cutlass_fmha_fwd_kvcache_impl(
372372
const fmha_fwd_kvcache_args_t& args,
373373
CutlassType cuType) {
374374

375-
// Dispatch based on head size
376-
// Uses explicitly instantiated kvcache_policy_dispatch<policy, stages, IsVarLen=0, IsPaged=0>
375+
// Dispatch based on head size and paged KV mode
377376
int head_size = args.head_size;
377+
bool is_paged = args.is_paged;
378+
379+
// Helper macro to dispatch both paged and non-paged variants
380+
#define DISPATCH_HEAD_SIZE(policy) \
381+
if (is_paged) { \
382+
kvcache_policy_dispatch<policy, 1, 0, 1>(queue, cuType, args); \
383+
} else { \
384+
kvcache_policy_dispatch<policy, 1, 0, 0>(queue, cuType, args); \
385+
}
378386

379387
if (head_size <= 32) {
380-
kvcache_policy_dispatch<prefill_policy_head32, 1, 0, 0>(queue, cuType, args);
388+
DISPATCH_HEAD_SIZE(prefill_policy_head32);
381389
} else if (head_size <= 64) {
382-
kvcache_policy_dispatch<prefill_policy_head64, 1, 0, 0>(queue, cuType, args);
390+
DISPATCH_HEAD_SIZE(prefill_policy_head64);
383391
} else if (head_size <= 96) {
384-
kvcache_policy_dispatch<prefill_policy_head96, 1, 0, 0>(queue, cuType, args);
392+
DISPATCH_HEAD_SIZE(prefill_policy_head96);
385393
} else if (head_size <= 128) {
386-
kvcache_policy_dispatch<prefill_policy_head128, 1, 0, 0>(queue, cuType, args);
394+
DISPATCH_HEAD_SIZE(prefill_policy_head128);
387395
} else if (head_size <= 160) {
388-
kvcache_policy_dispatch<prefill_policy_head160, 1, 0, 0>(queue, cuType, args);
396+
DISPATCH_HEAD_SIZE(prefill_policy_head160);
389397
} else if (head_size <= 192) {
390-
kvcache_policy_dispatch<prefill_policy_head192, 1, 0, 0>(queue, cuType, args);
398+
DISPATCH_HEAD_SIZE(prefill_policy_head192);
391399
} else if (head_size <= 256) {
392-
kvcache_policy_dispatch<prefill_policy_head256, 1, 0, 0>(queue, cuType, args);
400+
DISPATCH_HEAD_SIZE(prefill_policy_head256);
393401
} else {
394402
CUTLASS_ASSERT(false && "Unsupported head size for kvcache");
395403
}
404+
405+
#undef DISPATCH_HEAD_SIZE
396406
}

0 commit comments

Comments
 (0)