Skip to content

Commit 7f7267d

Browse files
committed
Add fused rotary support for XPU kvcache
1 parent 468dd6d commit 7f7267d

25 files changed

Lines changed: 783 additions & 74 deletions

flash-attn2/flash_attn_xpu/flash_api.cpp

Lines changed: 35 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -592,29 +592,19 @@ mha_fwd_kvcache(
592592
CHECK_DEVICE(seqlens_k);
593593
}
594594

595-
// Handle rotary embedding (pre-process in-place before kernel)
596-
if (rotary_cos_.has_value()) {
595+
at::Tensor rotary_cos, rotary_sin;
596+
int rotary_dim = 0;
597+
const bool has_rotary = rotary_cos_.has_value();
598+
if (has_rotary) {
597599
TORCH_CHECK(k_.has_value(), "If rotary cos/sin are provided, new key/value must also be provided");
598-
auto rotary_cos = rotary_cos_.value();
599-
auto rotary_sin = rotary_sin_.value();
600+
TORCH_CHECK(rotary_sin_.has_value(), "If rotary cos is provided, rotary sin must also be provided");
601+
rotary_cos = ensure_contiguous(rotary_cos_.value());
602+
rotary_sin = ensure_contiguous(rotary_sin_.value());
600603
CHECK_DEVICE(rotary_cos); CHECK_DEVICE(rotary_sin);
601-
int rotary_dim = rotary_cos.size(1) * 2;
604+
rotary_dim = rotary_cos.size(1) * 2;
602605
TORCH_CHECK(rotary_dim <= head_size_og, "rotary_dim must be <= headdim");
603606
TORCH_CHECK(rotary_dim % 16 == 0, "Only rotary dimensions divisible by 16 are currently supported");
604607
TORCH_CHECK(rotary_cos.scalar_type() == q_dtype && rotary_sin.scalar_type() == q_dtype);
605-
606-
std::optional<at::Tensor> seqlen_offsets_opt;
607-
if (seqlens_k_.has_value()) { seqlen_offsets_opt = seqlens_k; }
608-
609-
bool is_local = (window_size_left >= 0);
610-
if (is_causal || is_local) {
611-
apply_rotary_emb_inplace(q_padded, rotary_cos, rotary_sin, seqlen_offsets_opt, is_rotary_interleaved);
612-
} else {
613-
auto q_shape = q_padded.sizes();
614-
auto q_reshaped = q_padded.view({q_shape[0], 1, q_shape[1] * q_shape[2], q_shape[3]});
615-
apply_rotary_emb_inplace(q_reshaped, rotary_cos, rotary_sin, seqlen_offsets_opt, is_rotary_interleaved);
616-
}
617-
apply_rotary_emb_inplace(k_padded, rotary_cos, rotary_sin, seqlen_offsets_opt, is_rotary_interleaved);
618608
}
619609

620610
at::Tensor cache_batch_idx;
@@ -638,21 +628,30 @@ mha_fwd_kvcache(
638628
// - Always prefer kernel-fused scatter (passes knew/vnew to the kernel,
639629
// which writes them in-place during the prologue). This avoids any
640630
// host sync and works for both contiguous and paged caches.
641-
// - Fall back to API-layer scatter only when fusion is impossible:
642-
// * needs_padding: the cache pad is a separate buffer, so the
643-
// in-kernel writer would write to the padded copy, not the user
644-
// tensor; do the scatter on the user tensor and re-pad.
645-
// * rotary_cos: the rotary application happened on the padded
646-
// buffer; we need to slice off the padding before scattering to
647-
// the user cache. (Kernel-fused scatter copies the padded buffer
648-
// instead, which is wrong.)
631+
// - Fall back to API-layer scatter only when padding is needed: the
632+
// padded cache is a separate buffer, so the in-kernel writer would
633+
// not update the user tensor.
649634
bool fuse_knew = k_.has_value() && seqlen_new > 0
650-
&& !needs_padding && !rotary_cos_.has_value();
635+
&& !needs_padding;
636+
if (has_rotary && !fuse_knew) {
637+
std::optional<at::Tensor> seqlen_offsets_opt;
638+
if (seqlens_k_.has_value()) { seqlen_offsets_opt = seqlens_k; }
639+
640+
bool is_local = (window_size_left >= 0);
641+
if (is_causal || is_local) {
642+
apply_rotary_emb_inplace(q_padded, rotary_cos, rotary_sin, seqlen_offsets_opt, is_rotary_interleaved);
643+
} else {
644+
auto q_shape = q_padded.sizes();
645+
auto q_reshaped = q_padded.view({q_shape[0], 1, q_shape[1] * q_shape[2], q_shape[3]});
646+
apply_rotary_emb_inplace(q_reshaped, rotary_cos, rotary_sin, seqlen_offsets_opt, is_rotary_interleaved);
647+
}
648+
apply_rotary_emb_inplace(k_padded, rotary_cos, rotary_sin, seqlen_offsets_opt, is_rotary_interleaved);
649+
}
651650
if (k_.has_value() && seqlen_new > 0 && !fuse_knew) {
652651
auto seqlens_cpu = seqlens_k.to(torch::kCPU);
653652
auto seqlens_accessor = seqlens_cpu.accessor<int32_t, 1>();
654653

655-
at::Tensor k_for_cache = rotary_cos_.has_value()
654+
at::Tensor k_for_cache = has_rotary
656655
? k_padded.index({torch::indexing::Slice(), torch::indexing::Slice(),
657656
torch::indexing::Slice(), torch::indexing::Slice(0, head_size_og)}).contiguous()
658657
: ensure_contiguous(k_.value());
@@ -722,13 +721,20 @@ mha_fwd_kvcache(
722721
block_table_opt = block_table;
723722
}
724723

724+
std::optional<at::Tensor> rotary_cos_opt, rotary_sin_opt;
725+
if (fuse_knew && has_rotary) {
726+
rotary_cos_opt = rotary_cos;
727+
rotary_sin_opt = rotary_sin;
728+
}
729+
725730
cutlass_fmha_fwd_kvcache_impl(
726731
queue,
727732
q_padded, kcache_padded, vcache_padded,
728733
out, softmax_lse,
729734
seqlens_k, cache_batch_idx_opt, leftpad_k_opt,
730735
knew_opt, vnew_opt,
731-
block_table_opt, seqlen_k,
736+
block_table_opt, rotary_cos_opt, rotary_sin_opt,
737+
fuse_knew ? rotary_dim : 0, is_rotary_interleaved, seqlen_k,
732738
softmax_scale, window_size_left, window_size_right,
733739
is_causal, is_local);
734740

flash-attn2/flash_attn_xpu/src/collective/fmha_fwd_common.hpp

Lines changed: 42 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,48 @@ namespace cutlass::fmha::collective {
2222

2323
using namespace cute;
2424

25+
template <typename Element, typename RotaryElement>
26+
CUTLASS_DEVICE Element apply_rotary_scalar(
27+
Element x,
28+
Element x_pair,
29+
const RotaryElement* cos,
30+
const RotaryElement* sin,
31+
int position,
32+
int dim,
33+
int rotary_dim,
34+
bool interleaved) {
35+
if (rotary_dim == 0 || dim >= rotary_dim) {
36+
return x;
37+
}
38+
39+
int half_rotary = rotary_dim / 2;
40+
int cos_sin_idx = interleaved ? dim / 2
41+
: (dim < half_rotary ? dim : dim - half_rotary);
42+
bool is_second = interleaved ? (dim % 2) : (dim >= half_rotary);
43+
44+
float x_f = static_cast<float>(x);
45+
float x_pair_f = static_cast<float>(x_pair);
46+
float c = static_cast<float>(cos[position * half_rotary + cos_sin_idx]);
47+
float s = static_cast<float>(sin[position * half_rotary + cos_sin_idx]);
48+
float rotated = is_second ? x_pair_f * s + x_f * c
49+
: x_f * c - x_pair_f * s;
50+
return static_cast<Element>(rotated);
51+
}
52+
53+
CUTLASS_DEVICE int rotary_pair_dim(
54+
int dim,
55+
int rotary_dim,
56+
bool interleaved) {
57+
if (dim >= rotary_dim) {
58+
return dim;
59+
}
60+
if (interleaved) {
61+
return dim ^ 1;
62+
}
63+
int half_rotary = rotary_dim / 2;
64+
return dim < half_rotary ? dim + half_rotary : dim - half_rotary;
65+
}
66+
2567
/////////////////////////////////////////////////////////////////////////////////////////////////
2668
//
2769
// FMHAFwdMainloopTraits: common type aliases derived from TiledMMA / VTiles.

flash-attn2/flash_attn_xpu/src/collective/fmha_fwd_mainloop_xe2.hpp

Lines changed: 60 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -47,7 +47,8 @@ template <
4747
class TensorV_,
4848
class TiledCopyQ_ = void,
4949
class TiledCopyK_ = void,
50-
class TiledCopyV_ = void>
50+
class TiledCopyV_ = void,
51+
bool HasRotary_ = false>
5152
struct FMHAFwdMainloopXe2 {
5253
static_assert(
5354
cutlass::detail::dependent_false<DispatchPolicy_>,
@@ -70,7 +71,8 @@ template <
7071
class TensorV_,
7172
class TiledCopyQ_,
7273
class TiledCopyK_,
73-
class TiledCopyV_>
74+
class TiledCopyV_,
75+
bool HasRotary_>
7476
struct FMHAFwdMainloopXe2<
7577
Xe2<Stages>,
7678
CausalMask_,
@@ -85,7 +87,8 @@ struct FMHAFwdMainloopXe2<
8587
TensorV_,
8688
TiledCopyQ_,
8789
TiledCopyK_,
88-
TiledCopyV_> {
90+
TiledCopyV_,
91+
HasRotary_> {
8992

9093
// Pull in common type aliases from the shared traits.
9194
using Traits = FMHAFwdMainloopTraits<
@@ -123,6 +126,7 @@ struct FMHAFwdMainloopXe2<
123126
static constexpr bool LocalMask = LocalMask_;
124127
static constexpr bool HasDropout = HasDropout_;
125128
static constexpr bool PagedKV = PagedKV_;
129+
static constexpr bool HasRotary = HasRotary_;
126130

127131
// User-facing arguments
128132
struct Arguments {
@@ -142,6 +146,10 @@ struct FMHAFwdMainloopXe2<
142146
int page_size = 0;
143147
int max_pages_per_seq = 0;
144148
int total_seqlen_kv = 0;
149+
const typename TensorQ::element_type* rotary_cos = nullptr;
150+
const typename TensorQ::element_type* rotary_sin = nullptr;
151+
int rotary_dim = 0;
152+
bool is_rotary_interleaved = true;
145153
};
146154

147155
struct LocalMaskFields {
@@ -165,6 +173,14 @@ struct FMHAFwdMainloopXe2<
165173
};
166174
struct EmptyPaged {};
167175

176+
struct RotaryFields {
177+
const typename TensorQ::element_type* rotary_cos = nullptr;
178+
const typename TensorQ::element_type* rotary_sin = nullptr;
179+
int rotary_dim = 0;
180+
bool is_rotary_interleaved = true;
181+
};
182+
struct EmptyRotary {};
183+
168184
// Kernel-facing parameters
169185
struct Params {
170186
ElementS scale;
@@ -174,6 +190,8 @@ struct FMHAFwdMainloopXe2<
174190
dropout_fields;
175191
[[no_unique_address]] conditional_t<PagedKV, PagedKVFields, EmptyPaged>
176192
paged;
193+
[[no_unique_address]] conditional_t<HasRotary, RotaryFields, EmptyRotary>
194+
rotary;
177195
};
178196

179197
// SLM data
@@ -209,6 +227,10 @@ struct FMHAFwdMainloopXe2<
209227
p.paged = {args.ptr_page_table, args.page_size,
210228
args.max_pages_per_seq, args.total_seqlen_kv};
211229
}
230+
if constexpr (HasRotary) {
231+
p.rotary = {args.rotary_cos, args.rotary_sin,
232+
args.rotary_dim, args.is_rotary_interleaved};
233+
}
212234
return p;
213235
}
214236

@@ -236,7 +258,9 @@ struct FMHAFwdMainloopXe2<
236258
int& tile_row_idx,
237259
const int& rows_of_maxima,
238260
int head_q,
239-
int num_heads) {
261+
int num_heads,
262+
int q_offset_sg,
263+
int rotary_base) {
240264
using namespace sycl::ext::oneapi::this_work_item;
241265

242266
auto tile_shape_v =
@@ -387,6 +411,38 @@ struct FMHAFwdMainloopXe2<
387411
CUTLASS_PRAGMA_UNROLL
388412
for (int D = 0; D < size<4>(tKgK); D++) {
389413
copy(copy_q, tQgQ(_, _, _, D), tQrQ);
414+
if constexpr (HasRotary) {
415+
if (params.rotary.rotary_dim > 0 &&
416+
params.rotary.rotary_cos != nullptr &&
417+
params.rotary.rotary_sin != nullptr) {
418+
auto tQrQ_coords = tQrQ.tv_layout();
419+
int lane_id = static_cast<int>(get_sub_group().get_local_linear_id());
420+
int q_tile_base = get<0>(blk_qv) * get<0>(TileShapeQK{}) + q_offset_sg;
421+
int dim_tile_base = D * get<2>(TileShapeQK{});
422+
CUTLASS_PRAGMA_UNROLL
423+
for (int i = 0; i < tQrQ.size(); ++i) {
424+
auto value_coord = idx2crd(
425+
i, make_shape(
426+
get<1>(shape(tQrQ_coords)),
427+
get<2>(shape(tQrQ_coords))));
428+
auto coord = tQrQ_coords(
429+
make_coord(lane_id, get<0>(value_coord), get<1>(value_coord)));
430+
int row = q_tile_base + get<0>(coord);
431+
int dim = dim_tile_base + get<1>(coord);
432+
if (row < seq_len_qo && dim < params.rotary.rotary_dim) {
433+
int pair_dim = rotary_pair_dim(
434+
dim, params.rotary.rotary_dim,
435+
params.rotary.is_rotary_interleaved);
436+
int position = rotary_base + ((CausalMask || LocalMask) ? row : 0);
437+
tQrQ(i) = apply_rotary_scalar(
438+
tQrQ(i), Q_2D(row, pair_dim), params.rotary.rotary_cos,
439+
params.rotary.rotary_sin, position, dim,
440+
params.rotary.rotary_dim,
441+
params.rotary.is_rotary_interleaved);
442+
}
443+
}
444+
}
445+
}
390446
copy(copy_k, tKgK_cache(_, _, _, D), tKrK);
391447
reorder(tQrQ, tSrQ);
392448
reorder(tKrK, tSrK);

flash-attn2/flash_attn_xpu/src/create_instantiation_files.sh

Lines changed: 29 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -113,6 +113,35 @@ template void policy_dispatch_${dtype}<
113113
0, 1>(
114114
sycl::queue& queue,
115115
const fmha_fwd_args_t& args);
116+
117+
// Rotary kvcache variants keep rotary code out of non-rotary kernels.
118+
template void policy_dispatch_${dtype}<
119+
prefill_policy_head${hdim},
120+
PipelineStages_Prefill,
121+
0, 0, true>(
122+
sycl::queue& queue,
123+
const fmha_fwd_args_t& args);
124+
125+
template void policy_dispatch_${dtype}<
126+
decode_policy_head${hdim},
127+
PipelineStages_Decode,
128+
0, 0, true>(
129+
sycl::queue& queue,
130+
const fmha_fwd_args_t& args);
131+
132+
template void policy_dispatch_${dtype}<
133+
prefill_policy_head${hdim},
134+
PipelineStages_Prefill,
135+
0, 1, true>(
136+
sycl::queue& queue,
137+
const fmha_fwd_args_t& args);
138+
139+
template void policy_dispatch_${dtype}<
140+
decode_paged_policy_head${hdim},
141+
PipelineStages_Decode,
142+
0, 1, true>(
143+
sycl::queue& queue,
144+
const fmha_fwd_args_t& args);
116145
ENDFILE
117146
echo " Created flash_fwd_hdim${hdim}_kvcache_paged_${dtype}.cpp"
118147
done

flash-attn2/flash_attn_xpu/src/flash_fwd_hdim128_kvcache_paged_bf16.cpp

Lines changed: 29 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -18,3 +18,32 @@ template void policy_dispatch_bf16<
1818
0, 1>(
1919
sycl::queue& queue,
2020
const fmha_fwd_args_t& args);
21+
22+
// Rotary kvcache variants keep rotary code out of non-rotary kernels.
23+
template void policy_dispatch_bf16<
24+
prefill_policy_head128,
25+
PipelineStages_Prefill,
26+
0, 0, true>(
27+
sycl::queue& queue,
28+
const fmha_fwd_args_t& args);
29+
30+
template void policy_dispatch_bf16<
31+
decode_policy_head128,
32+
PipelineStages_Decode,
33+
0, 0, true>(
34+
sycl::queue& queue,
35+
const fmha_fwd_args_t& args);
36+
37+
template void policy_dispatch_bf16<
38+
prefill_policy_head128,
39+
PipelineStages_Prefill,
40+
0, 1, true>(
41+
sycl::queue& queue,
42+
const fmha_fwd_args_t& args);
43+
44+
template void policy_dispatch_bf16<
45+
decode_paged_policy_head128,
46+
PipelineStages_Decode,
47+
0, 1, true>(
48+
sycl::queue& queue,
49+
const fmha_fwd_args_t& args);

flash-attn2/flash_attn_xpu/src/flash_fwd_hdim128_kvcache_paged_fp16.cpp

Lines changed: 29 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -18,3 +18,32 @@ template void policy_dispatch_fp16<
1818
0, 1>(
1919
sycl::queue& queue,
2020
const fmha_fwd_args_t& args);
21+
22+
// Rotary kvcache variants keep rotary code out of non-rotary kernels.
23+
template void policy_dispatch_fp16<
24+
prefill_policy_head128,
25+
PipelineStages_Prefill,
26+
0, 0, true>(
27+
sycl::queue& queue,
28+
const fmha_fwd_args_t& args);
29+
30+
template void policy_dispatch_fp16<
31+
decode_policy_head128,
32+
PipelineStages_Decode,
33+
0, 0, true>(
34+
sycl::queue& queue,
35+
const fmha_fwd_args_t& args);
36+
37+
template void policy_dispatch_fp16<
38+
prefill_policy_head128,
39+
PipelineStages_Prefill,
40+
0, 1, true>(
41+
sycl::queue& queue,
42+
const fmha_fwd_args_t& args);
43+
44+
template void policy_dispatch_fp16<
45+
decode_paged_policy_head128,
46+
PipelineStages_Decode,
47+
0, 1, true>(
48+
sycl::queue& queue,
49+
const fmha_fwd_args_t& args);

0 commit comments

Comments (0)