Skip to content

Commit 4bd5dbc

Browse files
committed
refine
1 parent 61aa075 commit 4bd5dbc

14 files changed

Lines changed: 435 additions & 63 deletions

flash-attn2/build.toml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,14 @@ src = [
195195
"flash_attn_xpu/src/flash_fwd_hdim192_fix.cpp",
196196
"flash_attn_xpu/src/flash_fwd_hdim256_fix.cpp",
197197
"flash_attn_xpu/src/flash_fwd_hdim512_fix.cpp",
198+
"flash_attn_xpu/src/flash_fwd_hdim32_kvcache_paged.cpp",
199+
"flash_attn_xpu/src/flash_fwd_hdim64_kvcache_paged.cpp",
200+
"flash_attn_xpu/src/flash_fwd_hdim96_kvcache_paged.cpp",
201+
"flash_attn_xpu/src/flash_fwd_hdim128_kvcache_paged.cpp",
202+
"flash_attn_xpu/src/flash_fwd_hdim160_kvcache_paged.cpp",
203+
"flash_attn_xpu/src/flash_fwd_hdim192_kvcache_paged.cpp",
204+
"flash_attn_xpu/src/flash_fwd_hdim256_kvcache_paged.cpp",
205+
"flash_attn_xpu/src/flash_fwd_hdim512_kvcache_paged.cpp",
198206
"flash_attn_xpu/src/fmha_bwd_types.hpp",
199207
"flash_attn_xpu/src/fmha_bwd.hpp",
200208
"flash_attn_xpu/src/fmha_bwd_impl.hpp",

flash-attn2/flash_attn_xpu/flash_api.cpp

Lines changed: 25 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -632,10 +632,22 @@ mha_fwd_kvcache(
632632
TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32");
633633
}
634634

635-
// Write new K/V to cache in-place
636-
// Non-paged without padding: fused in kernel (knew/vnew passed to dispatch)
637-
// Paged or needs-padding: API-layer scatter (kernel fusion not applicable)
638-
bool fuse_knew = k_.has_value() && seqlen_new > 0 && !paged_KV && !needs_padding;
635+
// Write new K/V to cache.
636+
//
637+
// Strategy:
638+
// - Always prefer kernel-fused scatter (passes knew/vnew to the kernel,
639+
// which writes them in-place during the prologue). This avoids any
640+
// host sync and works for both contiguous and paged caches.
641+
// - Fall back to API-layer scatter only when fusion is impossible:
642+
// * needs_padding: the padded cache is a separate buffer, so the
643+
// in-kernel writer would write to the padded copy, not the user
644+
// tensor; do the scatter on the user tensor and re-pad.
645+
// * rotary_cos: rotary is applied to the padded buffer, so the
645+
// padding must be sliced off before scattering to the user
646+
// cache. (Kernel-fused scatter would copy the padded buffer
647+
// instead, which is wrong.)
649+
bool fuse_knew = k_.has_value() && seqlen_new > 0
650+
&& !needs_padding && !rotary_cos_.has_value();
639651
if (k_.has_value() && seqlen_new > 0 && !fuse_knew) {
640652
auto seqlens_cpu = seqlens_k.to(torch::kCPU);
641653
auto seqlens_accessor = seqlens_cpu.accessor<int32_t, 1>();
@@ -683,28 +695,8 @@ mha_fwd_kvcache(
683695
seqlens_k = seqlens_k + seqlen_new;
684696
}
685697

686-
// For paged KV, gather to contiguous format
687-
if (paged_KV) {
688-
int num_pages_needed = (seqlen_k + page_block_size - 1) / page_block_size;
689-
auto block_indices = block_table.index({
690-
torch::indexing::Slice(),
691-
torch::indexing::Slice(0, num_pages_needed)
692-
}).flatten();
693-
auto k_gathered = kcache_padded.index_select(0, block_indices.to(torch::kLong));
694-
auto v_gathered = vcache_padded.index_select(0, block_indices.to(torch::kLong));
695-
k_gathered = k_gathered.view({batch_size, num_pages_needed, page_block_size, num_heads_k, head_size_padded});
696-
v_gathered = v_gathered.view({batch_size, num_pages_needed, page_block_size, num_heads_k, head_size_padded});
697-
k_gathered = k_gathered.view({batch_size, num_pages_needed * page_block_size, num_heads_k, head_size_padded});
698-
v_gathered = v_gathered.view({batch_size, num_pages_needed * page_block_size, num_heads_k, head_size_padded});
699-
kcache_padded = k_gathered.index({
700-
torch::indexing::Slice(), torch::indexing::Slice(0, seqlen_k)
701-
}).contiguous();
702-
vcache_padded = v_gathered.index({
703-
torch::indexing::Slice(), torch::indexing::Slice(0, seqlen_k)
704-
}).contiguous();
705-
}
706-
707-
// Dispatch to kernel
698+
// Dispatch to kernel. Paged caches are now passed natively (block_table
699+
// routed straight through to the kernel, no host gather).
708700
auto queue = c10::xpu::getCurrentXPUStream(device_idx).queue();
709701
const bool is_local = (window_size_left >= 0);
710702

@@ -718,19 +710,25 @@ mha_fwd_kvcache(
718710
leftpad_k_opt = leftpad_k;
719711
}
720712

721-
// For non-paged path with new KV, pass knew/vnew for fused scatter in kernel
713+
// For paths where new KV is appended in-kernel, pass knew/vnew through.
722714
std::optional<at::Tensor> knew_opt, vnew_opt;
723715
if (fuse_knew) {
724716
knew_opt = k_padded;
725717
vnew_opt = v_padded;
726718
}
727719

720+
std::optional<at::Tensor> block_table_opt;
721+
if (paged_KV) {
722+
block_table_opt = block_table;
723+
}
724+
728725
cutlass_fmha_fwd_kvcache_impl(
729726
queue,
730727
q_padded, kcache_padded, vcache_padded,
731728
out, softmax_lse,
732729
seqlens_k, cache_batch_idx_opt, leftpad_k_opt,
733730
knew_opt, vnew_opt,
731+
block_table_opt, seqlen_k,
734732
softmax_scale, window_size_left, window_size_right,
735733
is_causal, is_local);
736734

flash-attn2/flash_attn_xpu/src/create_instantiation_files.sh

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,39 @@ ENDFILE
8787
echo " Created flash_fwd_hdim${hdim}_fix.cpp"
8888
done
8989

90+
echo ""
91+
echo "Creating kvcache-paged instantiation files (non-varlen + paged)..."
92+
for hdim in "${HDIMS[@]}"; do
93+
cat > flash_fwd_hdim${hdim}_kvcache_paged.cpp << ENDFILE
94+
#include "fmha_fwd_impl.hpp"
95+
96+
// Non-varlen + paged: IsVarLen=0, IsPaged=1
97+
// Used by mha_fwd_kvcache when block_table is provided.
98+
99+
// Prefill paged
100+
template void policy_dispatch<
101+
prefill_policy_head${hdim},
102+
PipelineStages_Prefill,
103+
0, 1>(
104+
sycl::queue& queue,
105+
CutlassType cuType,
106+
const fmha_fwd_args_t& args);
107+
108+
// Decode paged (smaller K-tile to fit page boundaries)
109+
template void policy_dispatch<
110+
decode_paged_policy_head${hdim},
111+
PipelineStages_Decode,
112+
0, 1>(
113+
sycl::queue& queue,
114+
CutlassType cuType,
115+
const fmha_fwd_args_t& args);
116+
ENDFILE
117+
echo " Created flash_fwd_hdim${hdim}_kvcache_paged.cpp"
118+
done
119+
90120
echo ""
91121
echo "✓ All instantiation files created successfully!"
92122
echo " - ${#HDIMS[@]} varlen files (IsVarLen=1, paged + non-paged)"
93123
echo " - ${#HDIMS[@]} fixed files (IsVarLen=0, decode + prefill)"
94-
echo " Total: $((${#HDIMS[@]} * 2)) files"
124+
echo " - ${#HDIMS[@]} kvcache_paged files (IsVarLen=0, IsPaged=1)"
125+
echo " Total: $((${#HDIMS[@]} * 3)) files"
flash-attn2/flash_attn_xpu/src/flash_fwd_hdim128_kvcache_paged.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
#include "fmha_fwd_impl.hpp"
2+
3+
// Non-varlen + paged: IsVarLen=0, IsPaged=1
4+
// Used by mha_fwd_kvcache when block_table is provided.
5+
6+
// Prefill paged
7+
template void policy_dispatch<
8+
prefill_policy_head128,
9+
PipelineStages_Prefill,
10+
0, 1>(
11+
sycl::queue& queue,
12+
CutlassType cuType,
13+
const fmha_fwd_args_t& args);
14+
15+
// Decode paged (smaller K-tile to fit page boundaries)
16+
template void policy_dispatch<
17+
decode_paged_policy_head128,
18+
PipelineStages_Decode,
19+
0, 1>(
20+
sycl::queue& queue,
21+
CutlassType cuType,
22+
const fmha_fwd_args_t& args);
flash-attn2/flash_attn_xpu/src/flash_fwd_hdim160_kvcache_paged.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
#include "fmha_fwd_impl.hpp"
2+
3+
// Non-varlen + paged: IsVarLen=0, IsPaged=1
4+
// Used by mha_fwd_kvcache when block_table is provided.
5+
6+
// Prefill paged
7+
template void policy_dispatch<
8+
prefill_policy_head160,
9+
PipelineStages_Prefill,
10+
0, 1>(
11+
sycl::queue& queue,
12+
CutlassType cuType,
13+
const fmha_fwd_args_t& args);
14+
15+
// Decode paged (smaller K-tile to fit page boundaries)
16+
template void policy_dispatch<
17+
decode_paged_policy_head160,
18+
PipelineStages_Decode,
19+
0, 1>(
20+
sycl::queue& queue,
21+
CutlassType cuType,
22+
const fmha_fwd_args_t& args);
flash-attn2/flash_attn_xpu/src/flash_fwd_hdim192_kvcache_paged.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
#include "fmha_fwd_impl.hpp"
2+
3+
// Non-varlen + paged: IsVarLen=0, IsPaged=1
4+
// Used by mha_fwd_kvcache when block_table is provided.
5+
6+
// Prefill paged
7+
template void policy_dispatch<
8+
prefill_policy_head192,
9+
PipelineStages_Prefill,
10+
0, 1>(
11+
sycl::queue& queue,
12+
CutlassType cuType,
13+
const fmha_fwd_args_t& args);
14+
15+
// Decode paged (smaller K-tile to fit page boundaries)
16+
template void policy_dispatch<
17+
decode_paged_policy_head192,
18+
PipelineStages_Decode,
19+
0, 1>(
20+
sycl::queue& queue,
21+
CutlassType cuType,
22+
const fmha_fwd_args_t& args);
flash-attn2/flash_attn_xpu/src/flash_fwd_hdim256_kvcache_paged.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
#include "fmha_fwd_impl.hpp"
2+
3+
// Non-varlen + paged: IsVarLen=0, IsPaged=1
4+
// Used by mha_fwd_kvcache when block_table is provided.
5+
6+
// Prefill paged
7+
template void policy_dispatch<
8+
prefill_policy_head256,
9+
PipelineStages_Prefill,
10+
0, 1>(
11+
sycl::queue& queue,
12+
CutlassType cuType,
13+
const fmha_fwd_args_t& args);
14+
15+
// Decode paged (smaller K-tile to fit page boundaries)
16+
template void policy_dispatch<
17+
decode_paged_policy_head256,
18+
PipelineStages_Decode,
19+
0, 1>(
20+
sycl::queue& queue,
21+
CutlassType cuType,
22+
const fmha_fwd_args_t& args);
flash-attn2/flash_attn_xpu/src/flash_fwd_hdim32_kvcache_paged.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
#include "fmha_fwd_impl.hpp"
2+
3+
// Non-varlen + paged: IsVarLen=0, IsPaged=1
4+
// Used by mha_fwd_kvcache when block_table is provided.
5+
6+
// Prefill paged
7+
template void policy_dispatch<
8+
prefill_policy_head32,
9+
PipelineStages_Prefill,
10+
0, 1>(
11+
sycl::queue& queue,
12+
CutlassType cuType,
13+
const fmha_fwd_args_t& args);
14+
15+
// Decode paged (smaller K-tile to fit page boundaries)
16+
template void policy_dispatch<
17+
decode_paged_policy_head32,
18+
PipelineStages_Decode,
19+
0, 1>(
20+
sycl::queue& queue,
21+
CutlassType cuType,
22+
const fmha_fwd_args_t& args);
flash-attn2/flash_attn_xpu/src/flash_fwd_hdim512_kvcache_paged.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
#include "fmha_fwd_impl.hpp"
2+
3+
// Non-varlen + paged: IsVarLen=0, IsPaged=1
4+
// Used by mha_fwd_kvcache when block_table is provided.
5+
6+
// Prefill paged
7+
template void policy_dispatch<
8+
prefill_policy_head512,
9+
PipelineStages_Prefill,
10+
0, 1>(
11+
sycl::queue& queue,
12+
CutlassType cuType,
13+
const fmha_fwd_args_t& args);
14+
15+
// Decode paged (smaller K-tile to fit page boundaries)
16+
template void policy_dispatch<
17+
decode_paged_policy_head512,
18+
PipelineStages_Decode,
19+
0, 1>(
20+
sycl::queue& queue,
21+
CutlassType cuType,
22+
const fmha_fwd_args_t& args);
flash-attn2/flash_attn_xpu/src/flash_fwd_hdim64_kvcache_paged.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
#include "fmha_fwd_impl.hpp"
2+
3+
// Non-varlen + paged: IsVarLen=0, IsPaged=1
4+
// Used by mha_fwd_kvcache when block_table is provided.
5+
6+
// Prefill paged
7+
template void policy_dispatch<
8+
prefill_policy_head64,
9+
PipelineStages_Prefill,
10+
0, 1>(
11+
sycl::queue& queue,
12+
CutlassType cuType,
13+
const fmha_fwd_args_t& args);
14+
15+
// Decode paged (smaller K-tile to fit page boundaries)
16+
template void policy_dispatch<
17+
decode_paged_policy_head64,
18+
PipelineStages_Decode,
19+
0, 1>(
20+
sycl::queue& queue,
21+
CutlassType cuType,
22+
const fmha_fwd_args_t& args);

0 commit comments

Comments
 (0)