From 11e58a500e3160d62d5bc2dbd78d7a390c0d87c9 Mon Sep 17 00:00:00 2001 From: YangKai0616 Date: Mon, 27 Apr 2026 05:54:54 +0000 Subject: [PATCH 1/8] Rebase and change kDropMaskMax --- flash-attn2/flash_attn_xpu/src/fmha_bwd_impl.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flash-attn2/flash_attn_xpu/src/fmha_bwd_impl.hpp b/flash-attn2/flash_attn_xpu/src/fmha_bwd_impl.hpp index 640d43c0..d6a71bad 100644 --- a/flash-attn2/flash_attn_xpu/src/fmha_bwd_impl.hpp +++ b/flash-attn2/flash_attn_xpu/src/fmha_bwd_impl.hpp @@ -605,7 +605,7 @@ dq_dk_dv_1colblock(Trait &trait, BwdParam ¶m, // Math: dS = scale * P * (mask * rp * dP_dropped - dpsum) // where P is the original softmax output, dP_dropped = dO * V^T, // and dpsum = sum(dO * O) with O having rp scaling from forward. - constexpr int kDropMaskMax = 128; + constexpr int kDropMaskMax = decltype(size(rS))::value; bool drop_keep[kDropMaskMax]; int drop_mask_count = 0; From 3f04867e446fe33e4617f403cb68f845de1df0bc Mon Sep 17 00:00:00 2001 From: YangKai0616 Date: Mon, 27 Apr 2026 12:41:17 +0000 Subject: [PATCH 2/8] Rebase --- flash-attn2/build.toml | 1 + flash-attn2/flash_attn_xpu/flash_api.cpp | 341 ++++++++++++++++++ .../flash_attn_xpu/src/fmha_bwd_impl.hpp | 2 +- flash-attn2/flash_attn_xpu/src/fmha_fwd.cpp | 85 +++++ flash-attn2/flash_attn_xpu/src/fmha_fwd.hpp | 18 + .../flash_attn_xpu/src/fmha_fwd_impl.hpp | 10 +- .../flash_attn_xpu/src/fmha_fwd_types.hpp | 16 + .../src/kernel/fmha_fwd_kernel_xe2.hpp | 82 ++++- flash-attn2/flash_attn_xpu/src/rotary.hpp | 118 ++++++ flash-attn2/tests/test_flash_attn.py | 3 +- flash-attn2/torch-ext/torch_binding.cpp | 2 + 11 files changed, 667 insertions(+), 11 deletions(-) create mode 100644 flash-attn2/flash_attn_xpu/src/rotary.hpp diff --git a/flash-attn2/build.toml b/flash-attn2/build.toml index 0a4aff81..b979ad45 100644 --- a/flash-attn2/build.toml +++ b/flash-attn2/build.toml @@ -174,6 +174,7 @@ depends = [ src = [ "flash_attn_xpu/flash_api.cpp", "flash_attn_xpu/src/philox.hpp", + "flash_attn_xpu/src/rotary.hpp", "flash_attn_xpu/src/fmha_fwd_types.hpp", "flash_attn_xpu/src/fmha_fwd.hpp", "flash_attn_xpu/src/fmha_fwd_impl.hpp", diff --git a/flash-attn2/flash_attn_xpu/flash_api.cpp b/flash-attn2/flash_attn_xpu/flash_api.cpp index a4bb4c5c..b7268ec5 100644 --- a/flash-attn2/flash_attn_xpu/flash_api.cpp +++ b/flash-attn2/flash_attn_xpu/flash_api.cpp @@ -2,6 +2,8 @@ #include #include #include +#include +#include #include "src/fmha_fwd.hpp" #include "src/fmha_bwd.hpp" @@ -466,6 +468,287 @@ mha_varlen_fwd( at::Tensor rng_state; return {out, softmax_lse, S_dmask, rng_state}; } + +#include "src/rotary.hpp" + +std::vector +mha_fwd_kvcache( + at::Tensor &q, + const at::Tensor &kcache, + const at::Tensor &vcache, + std::optional &k_, + std::optional &v_, + std::optional &seqlens_k_, + std::optional &rotary_cos_, + std::optional &rotary_sin_, + std::optional &cache_batch_idx_, + std::optional &leftpad_k_, + std::optional &block_table_, + std::optional &alibi_slopes_, + std::optional &out_, + const float softmax_scale, + bool is_causal, + int window_size_left, + int window_size_right, + const float softcap, + bool is_rotary_interleaved, + int num_splits) { + auto device_idx = q.device().index(); + compat::select_device(device_idx); + + TORCH_CHECK(!alibi_slopes_.has_value(), + "FlashAttention KVCache on XPU does not support alibi_slopes."); + + auto q_dtype = q.dtype(); + TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, + "FlashAttention KVCache only supports fp16 and bf16 data type"); + TORCH_CHECK(kcache.dtype() == q_dtype, "query and key cache must have the same dtype"); + TORCH_CHECK(vcache.dtype() == q_dtype, "query and value cache must have the same dtype"); + + CHECK_DEVICE(q); CHECK_DEVICE(kcache); CHECK_DEVICE(vcache); + TORCH_CHECK(q.stride(-1) == 1, "Query must have contiguous last dimension"); + TORCH_CHECK(kcache.stride(-1) == 1, "Key cache must have contiguous last dimension"); + TORCH_CHECK(vcache.stride(-1) == 1, "Value cache must have contiguous last dimension"); + + const bool paged_KV = block_table_.has_value(); + at::Tensor block_table; + if (paged_KV) { + TORCH_CHECK(!cache_batch_idx_.has_value(), "Paged KVcache does not support cache_batch_idx"); + block_table = block_table_.value(); + CHECK_DEVICE(block_table); + TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32"); + } + + const auto sizes = q.sizes(); + const int batch_size = sizes[0]; + int seqlen_q = sizes[1]; + int num_heads = sizes[2]; + const int head_size_og = sizes[3]; + + const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1); + const int page_block_size = !paged_KV ? 1 : kcache.size(1); + const int seqlen_k = !paged_KV ? kcache.size(1) : max_num_blocks_per_seq * page_block_size; + const int num_heads_k = kcache.size(2); + + TORCH_CHECK(batch_size > 0, "batch size must be positive"); + TORCH_CHECK(head_size_og <= 256 || head_size_og == 512, + "FlashAttention KVCache only supports head dimension up to 256 or exactly 512"); + TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + + if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; } + if (is_causal) { window_size_right = 0; } + if (window_size_left >= seqlen_k) { window_size_left = -1; } + if (window_size_right >= seqlen_k) { window_size_right = -1; } + + const int head_size_padded = round_multiple(head_size_og, 32); + const bool needs_padding = (head_size_og != head_size_padded); + const int pad_size = head_size_padded - head_size_og; + + auto maybe_pad = [&](const at::Tensor& t) -> at::Tensor { + return needs_padding + ? torch::nn::functional::pad(t, torch::nn::functional::PadFuncOptions({0, pad_size})) + : t; + }; + + at::Tensor q_padded = maybe_pad(ensure_contiguous(q)); + at::Tensor kcache_padded = maybe_pad(ensure_contiguous(kcache)); + at::Tensor vcache_padded = maybe_pad(ensure_contiguous(vcache)); + + at::Tensor out; + if (out_.has_value()) { + out = out_.value(); + TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs"); + CHECK_DEVICE(out); + if (needs_padding) { out = maybe_pad(out); } + } else { + out = torch::zeros_like(q_padded); + } + + auto opts = q.options(); + auto softmax_lse = torch::full({batch_size, num_heads, seqlen_q}, + -std::numeric_limits::infinity(), + opts.dtype(at::kFloat)); + + // Handle new K/V + at::Tensor k_padded, v_padded; + int seqlen_new = 0; + if (k_.has_value()) { + TORCH_CHECK(v_.has_value(), "If key is supplied, value must also be passed in"); + TORCH_CHECK(seqlens_k_.has_value(), "If key is supplied, seqlens_k must also be passed in"); + auto k = k_.value(); + auto v = v_.value(); + TORCH_CHECK(k.dtype() == q_dtype && v.dtype() == q_dtype); + CHECK_DEVICE(k); CHECK_DEVICE(v); + TORCH_CHECK(k.stride(-1) == 1 && v.stride(-1) == 1); + seqlen_new = k.size(1); + k_padded = maybe_pad(ensure_contiguous(k)); + v_padded = maybe_pad(ensure_contiguous(v)); + } + + at::Tensor seqlens_k; + if (seqlens_k_.has_value()) { + seqlens_k = seqlens_k_.value(); + TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32"); + CHECK_DEVICE(seqlens_k); + } + + // Handle rotary embedding (pre-process in-place before kernel) + if (rotary_cos_.has_value()) { + TORCH_CHECK(k_.has_value(), "If rotary cos/sin are provided, new key/value must also be provided"); + auto rotary_cos = rotary_cos_.value(); + auto rotary_sin = rotary_sin_.value(); + CHECK_DEVICE(rotary_cos); CHECK_DEVICE(rotary_sin); + int rotary_dim = rotary_cos.size(1) * 2; + TORCH_CHECK(rotary_dim <= head_size_og, "rotary_dim must be <= headdim"); + TORCH_CHECK(rotary_dim % 16 == 0, "Only rotary dimensions divisible by 16 are currently supported"); + TORCH_CHECK(rotary_cos.scalar_type() == q_dtype && rotary_sin.scalar_type() == q_dtype); + + std::optional seqlen_offsets_opt; + if (seqlens_k_.has_value()) { seqlen_offsets_opt = seqlens_k; } + + bool is_local = (window_size_left >= 0); + if (is_causal || is_local) { + apply_rotary_emb_inplace(q_padded, rotary_cos, rotary_sin, seqlen_offsets_opt, is_rotary_interleaved); + } else { + auto q_shape = q_padded.sizes(); + auto q_reshaped = q_padded.view({q_shape[0], 1, q_shape[1] * q_shape[2], q_shape[3]}); + apply_rotary_emb_inplace(q_reshaped, rotary_cos, rotary_sin, seqlen_offsets_opt, is_rotary_interleaved); + } + apply_rotary_emb_inplace(k_padded, rotary_cos, rotary_sin, seqlen_offsets_opt, is_rotary_interleaved); + } + + at::Tensor cache_batch_idx; + if (cache_batch_idx_.has_value()) { + cache_batch_idx = cache_batch_idx_.value(); + CHECK_DEVICE(cache_batch_idx); + TORCH_CHECK(cache_batch_idx.scalar_type() == torch::kInt32); + } + + at::Tensor leftpad_k; + if (leftpad_k_.has_value()) { + TORCH_CHECK(!paged_KV, "Paged KV and leftpad_k are not supported together"); + leftpad_k = leftpad_k_.value(); + CHECK_DEVICE(leftpad_k); + TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32"); + } + + // Write new K/V to cache in-place + // Non-paged without padding: fused in kernel (knew/vnew passed to dispatch) + // Paged or needs-padding: API-layer scatter (kernel fusion not applicable) + bool fuse_knew = k_.has_value() && seqlen_new > 0 && !paged_KV && !needs_padding; + if (k_.has_value() && seqlen_new > 0 && !fuse_knew) { + auto seqlens_cpu = seqlens_k.to(torch::kCPU); + auto seqlens_accessor = seqlens_cpu.accessor(); + + at::Tensor k_for_cache = rotary_cos_.has_value() + ? k_padded.index({torch::indexing::Slice(), torch::indexing::Slice(), + torch::indexing::Slice(), torch::indexing::Slice(0, head_size_og)}).contiguous() + : ensure_contiguous(k_.value()); + at::Tensor v_for_cache = ensure_contiguous(v_.value()); + + at::Tensor kc = ensure_contiguous(kcache); + at::Tensor vc = ensure_contiguous(vcache); + + if (paged_KV) { + auto bt_cpu = block_table.to(torch::kCPU); + auto bt_acc = bt_cpu.accessor(); + for (int b = 0; b < batch_size; b++) { + int cache_seqlen = seqlens_accessor[b]; + for (int s = 0; s < seqlen_new; s++) { + int global_pos = cache_seqlen + s; + int page_idx = global_pos / page_block_size; + int page_offset = global_pos % page_block_size; + int block_idx = bt_acc[b][page_idx]; + kc.index({block_idx, page_offset}) = k_for_cache.index({b, s}); + vc.index({block_idx, page_offset}) = v_for_cache.index({b, s}); + } + } + } else { + for (int b = 0; b < batch_size; b++) { + int cache_b = cache_batch_idx_.has_value() + ? cache_batch_idx.index({b}).item() : b; + int cache_seqlen = seqlens_accessor[b]; + int write_start = cache_seqlen; + TORCH_CHECK(write_start + seqlen_new <= seqlen_k, + "Cache overflow: cache_seqlen + seqlen_new > cache capacity"); + kc.index({cache_b, torch::indexing::Slice(write_start, write_start + seqlen_new)}) = + k_for_cache.index({b}); + vc.index({cache_b, torch::indexing::Slice(write_start, write_start + seqlen_new)}) = + v_for_cache.index({b}); + } + } + + kcache_padded = maybe_pad(kc); + vcache_padded = maybe_pad(vc); + seqlens_k = seqlens_k + seqlen_new; + } + + // For paged KV, gather to contiguous format + if (paged_KV) { + int num_pages_needed = (seqlen_k + page_block_size - 1) / page_block_size; + auto block_indices = block_table.index({ + torch::indexing::Slice(), + torch::indexing::Slice(0, num_pages_needed) + }).flatten(); + auto k_gathered = kcache_padded.index_select(0, block_indices.to(torch::kLong)); + auto v_gathered = vcache_padded.index_select(0, block_indices.to(torch::kLong)); + k_gathered = k_gathered.view({batch_size, num_pages_needed, page_block_size, num_heads_k, head_size_padded}); + v_gathered = v_gathered.view({batch_size, num_pages_needed, page_block_size, num_heads_k, head_size_padded}); + k_gathered = k_gathered.view({batch_size, num_pages_needed * page_block_size, num_heads_k, head_size_padded}); + v_gathered = v_gathered.view({batch_size, num_pages_needed * page_block_size, num_heads_k, head_size_padded}); + kcache_padded = k_gathered.index({ + torch::indexing::Slice(), torch::indexing::Slice(0, seqlen_k) + }).contiguous(); + vcache_padded = v_gathered.index({ + torch::indexing::Slice(), torch::indexing::Slice(0, seqlen_k) + }).contiguous(); + } + + // Dispatch to kernel + auto queue = c10::xpu::getCurrentXPUStream(device_idx).queue(); + const bool is_local = (window_size_left >= 0); + + std::optional cache_batch_idx_opt; + if (cache_batch_idx_.has_value()) { + cache_batch_idx_opt = cache_batch_idx; + } + + std::optional leftpad_k_opt; + if (leftpad_k_.has_value()) { + leftpad_k_opt = leftpad_k; + } + + // For non-paged path with new KV, pass knew/vnew for fused scatter in kernel + std::optional knew_opt, vnew_opt; + if (fuse_knew) { + knew_opt = k_padded; + vnew_opt = v_padded; + } + + cutlass_fmha_fwd_kvcache_impl( + queue, + q_padded, kcache_padded, vcache_padded, + out, softmax_lse, + seqlens_k, cache_batch_idx_opt, leftpad_k_opt, + knew_opt, vnew_opt, + softmax_scale, window_size_left, window_size_right, + is_causal, is_local); + + // Update seqlens_k after kernel completes (for fused scatter path) + if (fuse_knew) { + seqlens_k = seqlens_k + seqlen_new; + } + + if (needs_padding) { + out = out.index({torch::indexing::Slice(), torch::indexing::Slice(), + torch::indexing::Slice(), torch::indexing::Slice(0, head_size_og)}) + .contiguous(); + if (out_.has_value()) { out_.value().copy_(out); } + } + + return {out, softmax_lse}; +} + } // namespace FLASH_NAMESPACE // std::tuple @@ -633,4 +916,62 @@ mha_bwd(const torch::Tensor &dout, gen_, rng_opt ); +} + +std::vector +mha_fwd_kvcache( + const torch::Tensor &q, + const torch::Tensor &kcache, + const torch::Tensor &vcache, + const c10::optional &k_, + const c10::optional &v_, + const c10::optional &seqlens_k_, + const c10::optional &rotary_cos_, + const c10::optional &rotary_sin_, + const c10::optional &cache_batch_idx_, + const c10::optional &leftpad_k_, + const c10::optional &block_table_, + const c10::optional &alibi_slopes_, + const c10::optional &out_, + const double softmax_scale, + bool is_causal, + const int64_t window_size_left, + const int64_t window_size_right, + const double softcap, + bool is_rotary_interleaved, + const int64_t num_splits) { + // Convert c10::optional -> std::optional for the internal API + auto to_std_opt = [](const c10::optional& opt) -> std::optional { + return opt.has_value() ? std::optional(opt.value()) : std::nullopt; + }; + auto to_std_opt_const = [](const c10::optional& opt) -> std::optional { + return opt.has_value() ? std::optional(opt.value()) : std::nullopt; + }; + + at::Tensor q_mut = q; + auto k_opt = to_std_opt_const(k_); + auto v_opt = to_std_opt_const(v_); + auto seqlens_opt = to_std_opt_const(seqlens_k_); + auto rotary_cos_opt = to_std_opt_const(rotary_cos_); + auto rotary_sin_opt = to_std_opt_const(rotary_sin_); + auto cache_batch_idx_opt = to_std_opt_const(cache_batch_idx_); + auto leftpad_k_opt = to_std_opt_const(leftpad_k_); + auto block_table_opt = to_std_opt(block_table_); + auto alibi_opt = to_std_opt(alibi_slopes_); + auto out_opt = to_std_opt(out_); + + return FLASH_NAMESPACE::mha_fwd_kvcache( + q_mut, kcache, vcache, + k_opt, v_opt, seqlens_opt, + rotary_cos_opt, rotary_sin_opt, + cache_batch_idx_opt, leftpad_k_opt, + block_table_opt, alibi_opt, out_opt, + static_cast(softmax_scale), + is_causal, + static_cast(window_size_left), + static_cast(window_size_right), + static_cast(softcap), + is_rotary_interleaved, + static_cast(num_splits) + ); } \ No newline at end of file diff --git a/flash-attn2/flash_attn_xpu/src/fmha_bwd_impl.hpp b/flash-attn2/flash_attn_xpu/src/fmha_bwd_impl.hpp index d6a71bad..ea7ac3ba 100644 --- a/flash-attn2/flash_attn_xpu/src/fmha_bwd_impl.hpp +++ b/flash-attn2/flash_attn_xpu/src/fmha_bwd_impl.hpp @@ -605,7 +605,7 @@ dq_dk_dv_1colblock(Trait &trait, BwdParam ¶m, // Math: dS = scale * P * (mask * rp * dP_dropped - dpsum) // where P is the original softmax output, dP_dropped = dO * V^T, // and dpsum = sum(dO * O) with O having rp scaling from forward. - constexpr int kDropMaskMax = decltype(size(rS))::value; + constexpr int kDropMaskMax = decltype(size<0>(scores))::value * decltype(size<1>(scores))::value; bool drop_keep[kDropMaskMax]; int drop_mask_count = 0; diff --git a/flash-attn2/flash_attn_xpu/src/fmha_fwd.cpp b/flash-attn2/flash_attn_xpu/src/fmha_fwd.cpp index 7e52e200..f80f4590 100644 --- a/flash-attn2/flash_attn_xpu/src/fmha_fwd.cpp +++ b/flash-attn2/flash_attn_xpu/src/fmha_fwd.cpp @@ -306,3 +306,88 @@ void cutlass_fmha_fwd_fix_impl( dispatch_fwd_prefill_by_head(queue, cuType, args, h); } } + +void cutlass_fmha_fwd_kvcache_impl( + sycl::queue& queue, + const at::Tensor& query, + const at::Tensor& kcache, + const at::Tensor& vcache, + at::Tensor& out, + at::Tensor& softmax_lse, + const at::Tensor& seqlens_k, + const std::optional& cache_batch_idx, + const std::optional& cache_leftpad, + const std::optional& knew, + const std::optional& vnew, + float sm_scale, + int window_size_left, + int window_size_right, + bool is_causal, + bool is_local) { + const int batch_size = query.size(0); + const int max_seqlen_q = query.size(1); + const int num_heads_q = query.size(2); + const int head_size = query.size(3); + + const int max_seqlen_k = kcache.size(1); + const int num_heads_kv = kcache.size(2); + + const int total_seqlen_q = batch_size * max_seqlen_q; + const int total_seqlen_k = batch_size * max_seqlen_k; + + normalize_window_params(window_size_left, window_size_right, + is_causal, is_local, max_seqlen_k); + + fmha_fwd_args_t args = { + query.data_ptr(), + kcache.data_ptr(), + vcache.data_ptr(), + out.data_ptr(), + softmax_lse.data_ptr(), + nullptr, // block_table + nullptr, // cu_seqlens_q + nullptr, // cu_seqlens_k + max_seqlen_q, + max_seqlen_k, + total_seqlen_q, + total_seqlen_k, + sm_scale, + batch_size, + num_heads_q, + num_heads_kv, + head_size, + 0, // max_blocks_per_seq + 0, // block_size + window_size_left, + window_size_right, + false, // is_varlen + false, // is_paged + is_causal, + is_local, + 0.0f, 0, 0, nullptr, nullptr, 0, 0, // dropout & s_dmask defaults + static_cast(seqlens_k.data_ptr()), + cache_batch_idx.has_value() + ? static_cast(cache_batch_idx->data_ptr()) + : nullptr, + cache_leftpad.has_value() + ? static_cast(cache_leftpad->data_ptr()) + : nullptr, + knew.has_value() ? knew->data_ptr() : nullptr, + vnew.has_value() ? vnew->data_ptr() : nullptr, + knew.has_value() ? static_cast(knew->size(1)) : 0, + knew.has_value() ? knew->stride(0) : 0, + knew.has_value() ? knew->stride(2) : 0, + knew.has_value() ? knew->stride(1) : 0, + vnew.has_value() ? vnew->stride(0) : 0, + vnew.has_value() ? vnew->stride(2) : 0, + vnew.has_value() ? vnew->stride(1) : 0}; + + const CutlassType cuType = aten_to_Cutlass_dtype(query); + const int h = args.head_size; + + if (max_seqlen_q == 1) { + dispatch_fwd_decode_by_head(queue, cuType, args, h); + } else { + dispatch_fwd_prefill_by_head(queue, cuType, args, h); + } +} diff --git a/flash-attn2/flash_attn_xpu/src/fmha_fwd.hpp b/flash-attn2/flash_attn_xpu/src/fmha_fwd.hpp index baa3758d..9cc9fd0e 100644 --- a/flash-attn2/flash_attn_xpu/src/fmha_fwd.hpp +++ b/flash-attn2/flash_attn_xpu/src/fmha_fwd.hpp @@ -158,3 +158,21 @@ void cutlass_fmha_fwd_fix_impl( void* s_dmask = nullptr, int seqlen_q_rounded = 0, int seqlen_k_rounded = 0); + +void cutlass_fmha_fwd_kvcache_impl( + sycl::queue& queue, + const at::Tensor& query, + const at::Tensor& kcache, + const at::Tensor& vcache, + at::Tensor& out, + at::Tensor& softmax_lse, + const at::Tensor& seqlens_k, + const std::optional& cache_batch_idx, + const std::optional& cache_leftpad, + const std::optional& knew, + const std::optional& vnew, + float sm_scale, + int window_size_left, + int window_size_right, + bool is_causal, + bool is_local); diff --git a/flash-attn2/flash_attn_xpu/src/fmha_fwd_impl.hpp b/flash-attn2/flash_attn_xpu/src/fmha_fwd_impl.hpp index 5e6d00e4..7f12a1e7 100644 --- a/flash-attn2/flash_attn_xpu/src/fmha_fwd_impl.hpp +++ b/flash-attn2/flash_attn_xpu/src/fmha_fwd_impl.hpp @@ -195,7 +195,15 @@ struct FMHAConfig { o_batch_stride, o_head_stride, o_row_stride, reinterpret_cast(args.softmax_lse), lse_stride_head, - lse_stride_batch}, + lse_stride_batch, + args.cache_seqlens, + args.cache_batch_idx, + args.cache_leftpad, + reinterpret_cast(args.knew), + args.knew_batch_stride, args.knew_head_stride, args.knew_row_stride, + reinterpret_cast(args.vnew), + args.vnew_batch_stride, args.vnew_head_stride, args.vnew_row_stride, + args.seqlen_knew}, {static_cast(args.sm_scale), args.window_size_left, args.window_size_right, args.p_dropout, diff --git a/flash-attn2/flash_attn_xpu/src/fmha_fwd_types.hpp b/flash-attn2/flash_attn_xpu/src/fmha_fwd_types.hpp index 6d004e1b..ce3f857e 100644 --- a/flash-attn2/flash_attn_xpu/src/fmha_fwd_types.hpp +++ b/flash-attn2/flash_attn_xpu/src/fmha_fwd_types.hpp @@ -39,6 +39,22 @@ struct fmha_fwd_args_t { void* s_dmask = nullptr; // Output: attention matrix with dropout sign-bit encoding int seqlen_q_rounded = 0; // Q sequence length rounded up (stride for s_dmask) int seqlen_k_rounded = 0; // K sequence length rounded up (stride for s_dmask) + + // KV Cache fields (optional, set to nullptr for non-kvcache paths) + int* cache_seqlens = nullptr; // (batch_size,) per-batch effective KV length + int* cache_batch_idx = nullptr; // (batch_size,) indices into KV cache batch dim + int* cache_leftpad = nullptr; // (batch_size,) left padding per batch in KV cache + + // Fused KV cache append: new K/V to scatter into cache inside the kernel + void* knew = nullptr; + void* vnew = nullptr; + int seqlen_knew = 0; + int64_t knew_batch_stride = 0; + int64_t knew_head_stride = 0; + int64_t knew_row_stride = 0; + int64_t vnew_batch_stride = 0; + int64_t vnew_head_stride = 0; + int64_t vnew_row_stride = 0; }; enum class CutlassType { diff --git a/flash-attn2/flash_attn_xpu/src/kernel/fmha_fwd_kernel_xe2.hpp b/flash-attn2/flash_attn_xpu/src/kernel/fmha_fwd_kernel_xe2.hpp index d594c7de..9a8f6018 100644 --- a/flash-attn2/flash_attn_xpu/src/kernel/fmha_fwd_kernel_xe2.hpp +++ b/flash-attn2/flash_attn_xpu/src/kernel/fmha_fwd_kernel_xe2.hpp @@ -98,6 +98,20 @@ class XeFMHAFwdKernelXe2 { float* pLSE; int lse_stride_head; int lse_stride_batch; + // KV Cache: per-batch effective KV length (nullptr for non-kvcache paths) + int* cache_seqlens = nullptr; + int* cache_batch_idx = nullptr; + int* cache_leftpad = nullptr; + // Fused KV cache append + const ElementK* Knew = nullptr; + int64_t knew_batch_stride = 0; + int64_t knew_head_stride = 0; + int64_t knew_row_stride = 0; + const ElementV* Vnew = nullptr; + int64_t vnew_batch_stride = 0; + int64_t vnew_head_stride = 0; + int64_t vnew_row_stride = 0; + int seqlen_knew = 0; }; using KernelParams = KernelArguments; @@ -211,11 +225,52 @@ class XeFMHAFwdKernelXe2 { if (blk_q * get<0>(TileShapeQK{}) >= seq_len_qo) continue; - auto full_tile_offset = seq_len_kv - seq_len_qo; + // KV Cache: override seq_len_kv with per-batch effective length + int effective_seq_kv = seq_len_kv; + int leftpad_k = 0; + if (p.cache_seqlens) { + int bidx = p.cache_batch_idx ? p.cache_batch_idx[idx_b] : idx_b; + int orig_cache_seqlens = p.cache_seqlens[bidx]; + if (p.cache_leftpad) { + leftpad_k = p.cache_leftpad[bidx]; + } + + // Fused cache update: copy knew/vnew into kcache/vcache + if (p.Knew != nullptr && p.seqlen_knew > 0) { + constexpr int num_threads = SGPerWG::value * cute::intel::sg_size; + auto* k_dst = const_cast(p.K) + + bidx * p.k_batch_stride + head * p.k_head_stride + + static_cast(orig_cache_seqlens) * p.k_row_stride; + auto* k_src = p.Knew + + idx_b * p.knew_batch_stride + head * p.knew_head_stride; + for (int si = 0; si < p.seqlen_knew; si++) { + for (int d = thr_id; d < s.head_size_qk; d += num_threads) { + k_dst[si * p.k_row_stride + d] = k_src[si * p.knew_row_stride + d]; + } + } + auto* v_dst = const_cast(p.V) + + bidx * p.v_batch_stride + head * p.v_head_stride + + static_cast(orig_cache_seqlens) * p.v_row_stride; + auto* v_src = p.Vnew + + idx_b * p.vnew_batch_stride + head * p.vnew_head_stride; + for (int si = 0; si < p.seqlen_knew; si++) { + for (int d = thr_id; d < s.head_size_vo; d += num_threads) { + v_dst[si * p.v_row_stride + d] = v_src[si * p.vnew_row_stride + d]; + } + } + sycl::group_barrier(get_work_group<3>()); + effective_seq_kv = (orig_cache_seqlens + p.seqlen_knew) - leftpad_k; + } else { + effective_seq_kv = orig_cache_seqlens - leftpad_k; + } + } + if (effective_seq_kv <= 0) continue; + + auto full_tile_offset = effective_seq_kv - seq_len_qo; int seq_coord = cute::min( seq_len_qo, (blk_q * get<0>(TileShapeQK{}) + q_offset_sg)); int last_seq_coord = seq_coord + q_sg_tile - 1; - int first_non_masked_sequence = seq_len_qo - seq_len_kv; + int first_non_masked_sequence = seq_len_qo - effective_seq_kv; // Causal-only early-exit: skip SGs that are fully masked. With // LocalMask we can't easily do this here, so let the loop body mask. @@ -228,15 +283,15 @@ class XeFMHAFwdKernelXe2 { int seq_len; if constexpr (CausalMask && LocalMask) { seq_len = cute::min( - seq_len_kv, + effective_seq_kv, full_tile_offset + seq_coord + q_sg_tile + params.mainloop.local.local_right); } else if constexpr (CausalMask) { seq_len = calculate_longest_non_masked_length( - seq_len_kv, seq_len_qo, last_seq_coord, + effective_seq_kv, seq_len_qo, last_seq_coord, first_non_masked_sequence); } else { - seq_len = seq_len_kv; + seq_len = effective_seq_kv; } if (seq_len < 0) seq_len = 0; @@ -272,12 +327,17 @@ class XeFMHAFwdKernelXe2 { } else { total_seqlen_kv = seq_len_kv; } + // When leftpad is applied, the K/V base pointer is shifted forward. + // Reduce the surface height so it accurately describes the data + // reachable from the shifted base, preventing 2D block loads from + // extending past the per-batch allocation. + int kv_surface_len = total_seqlen_kv - leftpad_k; auto shape_Q = make_shape(seq_len_qo, s.head_size_qk, s.num_heads_q, batch_dim); auto shape_K = make_shape( - total_seqlen_kv, s.head_size_qk, s.num_heads_kv, batch_dim); + kv_surface_len, s.head_size_qk, s.num_heads_kv, batch_dim); auto shape_V = make_shape( - s.head_size_vo, total_seqlen_kv, s.num_heads_kv, batch_dim); + s.head_size_vo, kv_surface_len, s.num_heads_kv, batch_dim); auto shape_O = make_shape(seq_len_qo, s.head_size_vo, s.num_heads_q, batch_dim); @@ -303,6 +363,12 @@ class XeFMHAFwdKernelXe2 { auto dcV = const_cast(p.V + offset_v); auto ptrO = p.O + offset_o; + // Offset K/V by leftpad to skip left-padding tokens in the cache + if (leftpad_k > 0) { + dcK += leftpad_k * static_cast(p.k_row_stride); + dcV += leftpad_k * static_cast(p.v_row_stride); + } + auto layout_q = is_var_len ? make_ordered_layout(shape_Q, Step<_2, _0, _1, _3>{}) : make_layout(shape_Q, stride_q); @@ -343,7 +409,7 @@ class XeFMHAFwdKernelXe2 { thr_id, seq_len, seq_len_qo, - seq_len_kv, + effective_seq_kv, idx_b, tile_row_idx, rows_of_maxima, diff --git a/flash-attn2/flash_attn_xpu/src/rotary.hpp b/flash-attn2/flash_attn_xpu/src/rotary.hpp new file mode 100644 index 00000000..bff33875 --- /dev/null +++ b/flash-attn2/flash_attn_xpu/src/rotary.hpp @@ -0,0 +1,118 @@ +#pragma once + +#include +#include +#include + +/// Apply rotary embedding to Q or K tensor (interleaved mode) +template +struct ApplyRotaryInterleavedKernel { + scalar_t* x; + const scalar_t* cos; + const scalar_t* sin; + const int* seqlen_offsets; + int batch_size, seqlen, num_heads, head_dim, rotary_dim, cos_sin_stride; + + void operator()(sycl::nd_item<1> item) const { + int idx = item.get_global_id(0); + int total_elements = batch_size * seqlen * num_heads * (rotary_dim / 2); + if (idx >= total_elements) return; + int half_rotary = rotary_dim / 2; + int pair_idx = idx % half_rotary; + int temp = idx / half_rotary; + int head_idx = temp % num_heads; + temp = temp / num_heads; + int seq_idx = temp % seqlen; + int batch_idx = temp / seqlen; + int pos = (seqlen_offsets != nullptr) ? seqlen_offsets[batch_idx] + seq_idx : seq_idx; + float c = static_cast(cos[pos * cos_sin_stride + pair_idx]); + float s = static_cast(sin[pos * cos_sin_stride + pair_idx]); + int base_offset = ((batch_idx * seqlen + seq_idx) * num_heads + head_idx) * head_dim; + int x0_idx = base_offset + pair_idx * 2; + int x1_idx = base_offset + pair_idx * 2 + 1; + float x0 = static_cast(x[x0_idx]); + float x1 = static_cast(x[x1_idx]); + x[x0_idx] = static_cast(x0 * c - x1 * s); + x[x1_idx] = static_cast(x0 * s + x1 * c); + } +}; + +/// Apply rotary embedding (non-interleaved / GPT-NeoX style) +template +struct ApplyRotaryContiguousKernel { + scalar_t* x; + const scalar_t* cos; + const scalar_t* sin; + const int* seqlen_offsets; + int batch_size, seqlen, num_heads, head_dim, rotary_dim, cos_sin_stride; + + void operator()(sycl::nd_item<1> item) const { + int idx = item.get_global_id(0); + int total_elements = batch_size * seqlen * num_heads * (rotary_dim / 2); + if (idx >= total_elements) return; + int half_rotary = rotary_dim / 2; + int pair_idx = idx % half_rotary; + int temp = idx / half_rotary; + int head_idx = temp % num_heads; + temp = temp / num_heads; + int seq_idx = temp % seqlen; + int batch_idx = temp / seqlen; + int pos = (seqlen_offsets != nullptr) ? seqlen_offsets[batch_idx] + seq_idx : seq_idx; + float c = static_cast(cos[pos * cos_sin_stride + pair_idx]); + float s = static_cast(sin[pos * cos_sin_stride + pair_idx]); + int base_offset = ((batch_idx * seqlen + seq_idx) * num_heads + head_idx) * head_dim; + int x0_idx = base_offset + pair_idx; + int x1_idx = base_offset + pair_idx + half_rotary; + float x0 = static_cast(x[x0_idx]); + float x1 = static_cast(x[x1_idx]); + x[x0_idx] = static_cast(x0 * c - x1 * s); + x[x1_idx] = static_cast(x0 * s + x1 * c); + } +}; + +inline void apply_rotary_emb_inplace( + at::Tensor& x, + const at::Tensor& cos, + const at::Tensor& sin, + const std::optional& seqlen_offsets, + bool interleaved) { + auto batch_size = x.size(0); + auto seqlen = x.size(1); + auto num_heads = x.size(2); + auto head_dim = x.size(3); + auto rotary_dim = cos.size(1) * 2; + TORCH_CHECK(rotary_dim <= head_dim, "rotary_dim must be <= head_dim"); + auto queue = c10::xpu::getCurrentXPUStream().queue(); + int total_pairs = batch_size * seqlen * num_heads * (rotary_dim / 2); + int wg_size = 256; + int num_groups = (total_pairs + wg_size - 1) / wg_size; + if (interleaved) { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::kBFloat16, at::kHalf, x.scalar_type(), "apply_rotary_interleaved", [&] { + const int* offset_ptr = seqlen_offsets.has_value() + ? seqlen_offsets->data_ptr() : nullptr; + ApplyRotaryInterleavedKernel kernel{ + x.data_ptr(), cos.data_ptr(), + sin.data_ptr(), offset_ptr, + (int)batch_size, (int)seqlen, (int)num_heads, + (int)head_dim, (int)rotary_dim, (int)cos.size(1)}; + queue.submit([&](sycl::handler& h) { + h.parallel_for(sycl::nd_range<1>(num_groups * wg_size, wg_size), kernel); + }); + }); + } else { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::kBFloat16, at::kHalf, x.scalar_type(), "apply_rotary_contiguous", [&] { + const int* offset_ptr = seqlen_offsets.has_value() + ? seqlen_offsets->data_ptr() : nullptr; + ApplyRotaryContiguousKernel kernel{ + x.data_ptr(), cos.data_ptr(), + sin.data_ptr(), offset_ptr, + (int)batch_size, (int)seqlen, (int)num_heads, + (int)head_dim, (int)rotary_dim, (int)cos.size(1)}; + queue.submit([&](sycl::handler& h) { + h.parallel_for(sycl::nd_range<1>(num_groups * wg_size, wg_size), kernel); + }); + }); + } +} diff --git a/flash-attn2/tests/test_flash_attn.py b/flash-attn2/tests/test_flash_attn.py index 5b138722..6fc65305 100644 --- a/flash-attn2/tests/test_flash_attn.py +++ b/flash-attn2/tests/test_flash_attn.py @@ -2026,7 +2026,8 @@ def test_flash_attn_kvcache( if device == "cpu": pytest.skip("kvcache not supported on CPU") if device == "xpu": - pytest.skip("kvcache not supported on xpu currently") + if alibi: + pytest.skip("alibi not supported on xpu currently") if seqlen_q > seqlen_k and new_kv: pytest.skip() if not new_kv and rotary_fraction > 0.0: diff --git a/flash-attn2/torch-ext/torch_binding.cpp b/flash-attn2/torch-ext/torch_binding.cpp index 37403521..b277ea46 100644 --- a/flash-attn2/torch-ext/torch_binding.cpp +++ b/flash-attn2/torch-ext/torch_binding.cpp @@ -144,6 +144,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "int num_splits) -> Tensor[]"); #if defined(CUDA_KERNEL) ops.impl("fwd_kvcache", torch::kCUDA, &mha_fwd_kvcache); +#elif defined(XPU_KERNEL) + ops.impl("fwd_kvcache", torch::kXPU, &mha_fwd_kvcache); #endif } From 5217798dcc4ff125bec72918e5e4c9bf2b28752a Mon Sep 17 00:00:00 2001 From: YangKai0616 Date: Tue, 28 Apr 2026 08:06:41 +0000 Subject: [PATCH 3/8] Roll back the prefetch operation to fix the paged kvcache bugs --- .../src/collective/fmha_fwd_mainloop_xe2.hpp | 98 +++++++++++-------- 1 file changed, 55 insertions(+), 43 deletions(-) diff --git a/flash-attn2/flash_attn_xpu/src/collective/fmha_fwd_mainloop_xe2.hpp b/flash-attn2/flash_attn_xpu/src/collective/fmha_fwd_mainloop_xe2.hpp index 4cff3697..75c4e359 100644 --- a/flash-attn2/flash_attn_xpu/src/collective/fmha_fwd_mainloop_xe2.hpp +++ b/flash-attn2/flash_attn_xpu/src/collective/fmha_fwd_mainloop_xe2.hpp @@ -288,10 +288,13 @@ struct FMHAFwdMainloopXe2< auto prefetch_q = make_block_2d_prefetch(copy_q); auto prefetch_k = make_block_2d_prefetch(copy_k); auto prefetch_v = make_block_2d_prefetch(copy_v); + auto prefetch_v_paged = + make_block_2d_prefetch(tile_shape_v, V_2D); auto pQgQ = prefetch_q.get_slice(thr_id).partition_S(gQ); auto pKgK = prefetch_k.get_slice(thr_id).partition_S(gK); auto pVgV = prefetch_v.get_slice(thr_id).partition_S(gV_split); + auto pVgV_paged = prefetch_v_paged.get_slice(thr_id).partition_S(gV); // PagedKV: translate logical K index to physical page-tile index. int tiles_per_page = 0; @@ -311,20 +314,27 @@ struct FMHAFwdMainloopXe2< for (int D = 0; D < size<3>(pQgQ); D++) { prefetch(prefetch_q, pQgQ(_, _, _, D)); } - int prefetch_k_stages = (total_blk < Stages ? total_blk : Stages); - for (int D = 0; D < size<4>(pKgK); D++) { + if constexpr (PagedKV) { CUTLASS_PRAGMA_UNROLL - for (int K = blk_k0; K < blk_k0 + prefetch_k_stages; K++) { - int pk; - if constexpr (PagedKV) { - int ploc = K * get<1>(TileShapeQK{}) / params.paged.page_size; - pk = params.paged.ptr_page_table[b_offset + ploc] * - tiles_per_page + - K % tiles_per_page; - } else { - pk = K; + for (int D = 0; D < size<4>(pKgK); D++) { + prefetch(prefetch_k, pKgK(_, _, _, next_page_idx, D)); + } + } else { + int prefetch_k_stages = (total_blk < Stages ? total_blk : Stages); + for (int D = 0; D < size<4>(pKgK); D++) { + CUTLASS_PRAGMA_UNROLL + for (int K = blk_k0; K < blk_k0 + prefetch_k_stages; K++) { + int pk; + if constexpr (PagedKV) { + int ploc = K * get<1>(TileShapeQK{}) / params.paged.page_size; + pk = params.paged.ptr_page_table[b_offset + ploc] * + tiles_per_page + + K % tiles_per_page; + } else { + pk = K; + } + prefetch(prefetch_k, pKgK(_, _, _, pk, D)); } - prefetch(prefetch_k, pKgK(_, _, _, pk, D)); } } clear(tArA); @@ -361,9 +371,11 @@ struct FMHAFwdMainloopXe2< auto tVgV_cache = PagedKV ? tVgV(_, _, _, _, page_idx) : tVgV(_, _, _, _, K); - // Non-paged: prefetch V before GEMM1 to overlap with computation. - // Paged: prefetch V after GEMM1 to avoid BMG hardware hang. - if constexpr (!PagedKV) { + // Paged path uses the old whole-V-tile prefetch pattern; split-V + // prefetch-after-GEMM can hang on BMG for some paged cases. + if constexpr (PagedKV) { + prefetch(prefetch_v_paged, pVgV_paged(_, _, _, page_idx)); + } else if constexpr (!PagedKV) { CUTLASS_PRAGMA_UNROLL for (int VV = 0; VV < VTiles; VV++) { prefetch(prefetch_v, pVgV(_, _, _, VV, K)); @@ -381,13 +393,6 @@ struct FMHAFwdMainloopXe2< cute::gemm(mma_qk, tSrQ, tSrK, tSrS); } - if constexpr (PagedKV) { - CUTLASS_PRAGMA_UNROLL - for (int VV = 0; VV < VTiles; VV++) { - prefetch(prefetch_v, pVgV(_, _, _, VV, page_idx)); - } - } - auto cS_thread = thr_mma_qk.partition_C(gP_all(_, _, K)); if (check_remainder_k && K == blk_k1 - 1) { @@ -522,28 +527,35 @@ struct FMHAFwdMainloopXe2< cute::gemm(mma_pv, tArP, tArV, tArA(_, _, _, VV)); } - int K_next = K + Stages; - if (K_next < blk_k1) { - if constexpr (PagedKV) { - int next_page_local_idx = - K_next * get<1>(TileShapeQK{}) / params.paged.page_size; - int pk_next; - if (next_page_local_idx < params.paged.max_pages_per_seq) { - pk_next = - params.paged.ptr_page_table[b_offset + next_page_local_idx] * - tiles_per_page + - K_next % tiles_per_page; + if constexpr (PagedKV) { + CUTLASS_PRAGMA_UNROLL + for (int D = 0; D < size<4>(pKgK); D++) { + prefetch(prefetch_k, pKgK(_, _, _, next_page_idx, D)); + } + } else { + int K_next = K + Stages; + if (K_next < blk_k1) { + if constexpr (PagedKV) { + int next_page_local_idx = + K_next * get<1>(TileShapeQK{}) / params.paged.page_size; + int pk_next; + if (next_page_local_idx < params.paged.max_pages_per_seq) { + pk_next = + params.paged.ptr_page_table[b_offset + next_page_local_idx] * + tiles_per_page + + K_next % tiles_per_page; + } else { + pk_next = params.paged.max_pages_per_seq * tiles_per_page - 1; + } + CUTLASS_PRAGMA_UNROLL + for (int D = 0; D < size<4>(pKgK); D++) { + prefetch(prefetch_k, pKgK(_, _, _, pk_next, D)); + } } else { - pk_next = params.paged.max_pages_per_seq * tiles_per_page - 1; - } - CUTLASS_PRAGMA_UNROLL - for (int D = 0; D < size<4>(pKgK); D++) { - prefetch(prefetch_k, pKgK(_, _, _, pk_next, D)); - } - } else { - CUTLASS_PRAGMA_UNROLL - for (int D = 0; D < size<4>(pKgK); D++) { - prefetch(prefetch_k, pKgK(_, _, _, K_next, D)); + CUTLASS_PRAGMA_UNROLL + for (int D = 0; D < size<4>(pKgK); D++) { + prefetch(prefetch_k, pKgK(_, _, _, K_next, D)); + } } } } From 468dd6d5f1dcce37e1d532389d80493068a61574 Mon Sep 17 00:00:00 2001 From: YangKai0616 Date: Wed, 6 May 2026 01:26:15 +0000 Subject: [PATCH 4/8] refine --- flash-attn2/build.toml | 16 +++ flash-attn2/flash_attn_xpu/flash_api.cpp | 52 +++++----- .../src/create_instantiation_files.sh | 33 ++++++- .../flash_fwd_hdim128_kvcache_paged_bf16.cpp | 20 ++++ .../flash_fwd_hdim128_kvcache_paged_fp16.cpp | 20 ++++ .../flash_fwd_hdim160_kvcache_paged_bf16.cpp | 20 ++++ .../flash_fwd_hdim160_kvcache_paged_fp16.cpp | 20 ++++ .../flash_fwd_hdim192_kvcache_paged_bf16.cpp | 20 ++++ .../flash_fwd_hdim192_kvcache_paged_fp16.cpp | 20 ++++ .../flash_fwd_hdim256_kvcache_paged_bf16.cpp | 20 ++++ .../flash_fwd_hdim256_kvcache_paged_fp16.cpp | 20 ++++ .../flash_fwd_hdim32_kvcache_paged_bf16.cpp | 20 ++++ .../flash_fwd_hdim32_kvcache_paged_fp16.cpp | 20 ++++ .../flash_fwd_hdim512_kvcache_paged_bf16.cpp | 20 ++++ .../flash_fwd_hdim512_kvcache_paged_fp16.cpp | 20 ++++ .../flash_fwd_hdim64_kvcache_paged_bf16.cpp | 20 ++++ .../flash_fwd_hdim64_kvcache_paged_fp16.cpp | 20 ++++ .../flash_fwd_hdim96_kvcache_paged_bf16.cpp | 20 ++++ .../flash_fwd_hdim96_kvcache_paged_fp16.cpp | 20 ++++ flash-attn2/flash_attn_xpu/src/fmha_fwd.cpp | 95 ++++++++++++++++-- flash-attn2/flash_attn_xpu/src/fmha_fwd.hpp | 19 ++++ .../src/kernel/fmha_fwd_kernel_xe2.hpp | 98 ++++++++++++++----- 22 files changed, 570 insertions(+), 63 deletions(-) create mode 100644 flash-attn2/flash_attn_xpu/src/flash_fwd_hdim128_kvcache_paged_bf16.cpp create mode 100644 flash-attn2/flash_attn_xpu/src/flash_fwd_hdim128_kvcache_paged_fp16.cpp create mode 100644 flash-attn2/flash_attn_xpu/src/flash_fwd_hdim160_kvcache_paged_bf16.cpp create mode 100644 flash-attn2/flash_attn_xpu/src/flash_fwd_hdim160_kvcache_paged_fp16.cpp create mode 100644 flash-attn2/flash_attn_xpu/src/flash_fwd_hdim192_kvcache_paged_bf16.cpp create mode 100644 flash-attn2/flash_attn_xpu/src/flash_fwd_hdim192_kvcache_paged_fp16.cpp create mode 100644 flash-attn2/flash_attn_xpu/src/flash_fwd_hdim256_kvcache_paged_bf16.cpp create mode 100644 flash-attn2/flash_attn_xpu/src/flash_fwd_hdim256_kvcache_paged_fp16.cpp create mode 100644 flash-attn2/flash_attn_xpu/src/flash_fwd_hdim32_kvcache_paged_bf16.cpp create mode 100644 flash-attn2/flash_attn_xpu/src/flash_fwd_hdim32_kvcache_paged_fp16.cpp create mode 100644 flash-attn2/flash_attn_xpu/src/flash_fwd_hdim512_kvcache_paged_bf16.cpp create mode 100644 flash-attn2/flash_attn_xpu/src/flash_fwd_hdim512_kvcache_paged_fp16.cpp create mode 100644 flash-attn2/flash_attn_xpu/src/flash_fwd_hdim64_kvcache_paged_bf16.cpp create mode 100644 flash-attn2/flash_attn_xpu/src/flash_fwd_hdim64_kvcache_paged_fp16.cpp create mode 100644 flash-attn2/flash_attn_xpu/src/flash_fwd_hdim96_kvcache_paged_bf16.cpp create mode 100644 flash-attn2/flash_attn_xpu/src/flash_fwd_hdim96_kvcache_paged_fp16.cpp diff --git a/flash-attn2/build.toml b/flash-attn2/build.toml index b979ad45..3a050c2d 100644 --- a/flash-attn2/build.toml +++ b/flash-attn2/build.toml @@ -211,6 +211,22 @@ src = [ "flash_attn_xpu/src/flash_fwd_hdim256_fix_bf16.cpp", "flash_attn_xpu/src/flash_fwd_hdim512_fix_fp16.cpp", "flash_attn_xpu/src/flash_fwd_hdim512_fix_bf16.cpp", + "flash_attn_xpu/src/flash_fwd_hdim32_kvcache_paged_fp16.cpp", + "flash_attn_xpu/src/flash_fwd_hdim32_kvcache_paged_bf16.cpp", + "flash_attn_xpu/src/flash_fwd_hdim64_kvcache_paged_fp16.cpp", + "flash_attn_xpu/src/flash_fwd_hdim64_kvcache_paged_bf16.cpp", + "flash_attn_xpu/src/flash_fwd_hdim96_kvcache_paged_fp16.cpp", + "flash_attn_xpu/src/flash_fwd_hdim96_kvcache_paged_bf16.cpp", + "flash_attn_xpu/src/flash_fwd_hdim128_kvcache_paged_fp16.cpp", + "flash_attn_xpu/src/flash_fwd_hdim128_kvcache_paged_bf16.cpp", + "flash_attn_xpu/src/flash_fwd_hdim160_kvcache_paged_fp16.cpp", + "flash_attn_xpu/src/flash_fwd_hdim160_kvcache_paged_bf16.cpp", + "flash_attn_xpu/src/flash_fwd_hdim192_kvcache_paged_fp16.cpp", + "flash_attn_xpu/src/flash_fwd_hdim192_kvcache_paged_bf16.cpp", + "flash_attn_xpu/src/flash_fwd_hdim256_kvcache_paged_fp16.cpp", + "flash_attn_xpu/src/flash_fwd_hdim256_kvcache_paged_bf16.cpp", + "flash_attn_xpu/src/flash_fwd_hdim512_kvcache_paged_fp16.cpp", + "flash_attn_xpu/src/flash_fwd_hdim512_kvcache_paged_bf16.cpp", "flash_attn_xpu/src/fmha_bwd_types.hpp", "flash_attn_xpu/src/fmha_bwd.hpp", "flash_attn_xpu/src/fmha_bwd_impl.hpp", diff --git a/flash-attn2/flash_attn_xpu/flash_api.cpp b/flash-attn2/flash_attn_xpu/flash_api.cpp index b7268ec5..c84e2e61 100644 --- a/flash-attn2/flash_attn_xpu/flash_api.cpp +++ b/flash-attn2/flash_attn_xpu/flash_api.cpp @@ -632,10 +632,22 @@ mha_fwd_kvcache( TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32"); } - // Write new K/V to cache in-place - // Non-paged without padding: fused in kernel (knew/vnew passed to dispatch) - // Paged or needs-padding: API-layer scatter (kernel fusion not applicable) - bool fuse_knew = k_.has_value() && seqlen_new > 0 && !paged_KV && !needs_padding; + // Write new K/V to cache. + // + // Strategy: + // - Always prefer kernel-fused scatter (passes knew/vnew to the kernel, + // which writes them in-place during the prologue). This avoids any + // host sync and works for both contiguous and paged caches. + // - Fall back to API-layer scatter only when fusion is impossible: + // * needs_padding: the cache pad is a separate buffer, so the + // in-kernel writer would write to the padded copy, not the user + // tensor; do the scatter on the user tensor and re-pad. + // * rotary_cos: the rotary application happened on the padded + // buffer; we need to slice off the padding before scattering to + // the user cache. (Kernel-fused scatter copies the padded buffer + // instead, which is wrong.) + bool fuse_knew = k_.has_value() && seqlen_new > 0 + && !needs_padding && !rotary_cos_.has_value(); if (k_.has_value() && seqlen_new > 0 && !fuse_knew) { auto seqlens_cpu = seqlens_k.to(torch::kCPU); auto seqlens_accessor = seqlens_cpu.accessor(); @@ -683,28 +695,8 @@ mha_fwd_kvcache( seqlens_k = seqlens_k + seqlen_new; } - // For paged KV, gather to contiguous format - if (paged_KV) { - int num_pages_needed = (seqlen_k + page_block_size - 1) / page_block_size; - auto block_indices = block_table.index({ - torch::indexing::Slice(), - torch::indexing::Slice(0, num_pages_needed) - }).flatten(); - auto k_gathered = kcache_padded.index_select(0, block_indices.to(torch::kLong)); - auto v_gathered = vcache_padded.index_select(0, block_indices.to(torch::kLong)); - k_gathered = k_gathered.view({batch_size, num_pages_needed, page_block_size, num_heads_k, head_size_padded}); - v_gathered = v_gathered.view({batch_size, num_pages_needed, page_block_size, num_heads_k, head_size_padded}); - k_gathered = k_gathered.view({batch_size, num_pages_needed * page_block_size, num_heads_k, head_size_padded}); - v_gathered = v_gathered.view({batch_size, num_pages_needed * page_block_size, num_heads_k, head_size_padded}); - kcache_padded = k_gathered.index({ - torch::indexing::Slice(), torch::indexing::Slice(0, seqlen_k) - }).contiguous(); - vcache_padded = v_gathered.index({ - torch::indexing::Slice(), torch::indexing::Slice(0, seqlen_k) - }).contiguous(); - } - - // Dispatch to kernel + // Dispatch to kernel. Paged caches are now passed natively (block_table + // routed straight through to the kernel, no host gather). auto queue = c10::xpu::getCurrentXPUStream(device_idx).queue(); const bool is_local = (window_size_left >= 0); @@ -718,19 +710,25 @@ mha_fwd_kvcache( leftpad_k_opt = leftpad_k; } - // For non-paged path with new KV, pass knew/vnew for fused scatter in kernel + // For paths where new KV is appended in-kernel, pass knew/vnew through. std::optional knew_opt, vnew_opt; if (fuse_knew) { knew_opt = k_padded; vnew_opt = v_padded; } + std::optional block_table_opt; + if (paged_KV) { + block_table_opt = block_table; + } + cutlass_fmha_fwd_kvcache_impl( queue, q_padded, kcache_padded, vcache_padded, out, softmax_lse, seqlens_k, cache_batch_idx_opt, leftpad_k_opt, knew_opt, vnew_opt, + block_table_opt, seqlen_k, softmax_scale, window_size_left, window_size_right, is_causal, is_local); diff --git a/flash-attn2/flash_attn_xpu/src/create_instantiation_files.sh b/flash-attn2/flash_attn_xpu/src/create_instantiation_files.sh index 90036e3b..c4b76554 100755 --- a/flash-attn2/flash_attn_xpu/src/create_instantiation_files.sh +++ b/flash-attn2/flash_attn_xpu/src/create_instantiation_files.sh @@ -88,8 +88,39 @@ ENDFILE done done +echo "" +echo "Creating kvcache-paged instantiation files (split by dtype)..." +for hdim in "${HDIMS[@]}"; do + for dtype in fp16 bf16; do + cat > flash_fwd_hdim${hdim}_kvcache_paged_${dtype}.cpp << ENDFILE +#include "fmha_fwd_impl.hpp" + +// Non-varlen + paged: IsVarLen=0, IsPaged=1, dtype=${dtype} +// Used by mha_fwd_kvcache when block_table is provided. + +// Prefill paged +template void policy_dispatch_${dtype}< + prefill_policy_head${hdim}, + PipelineStages_Prefill, + 0, 1>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +// Decode paged (smaller K-tile to fit page boundaries) +template void policy_dispatch_${dtype}< + decode_paged_policy_head${hdim}, + PipelineStages_Decode, + 0, 1>( + sycl::queue& queue, + const fmha_fwd_args_t& args); +ENDFILE + echo " Created flash_fwd_hdim${hdim}_kvcache_paged_${dtype}.cpp" + done +done + echo "" echo "✓ All instantiation files created successfully!" echo " - $((${#HDIMS[@]} * 2)) varlen files (split by dtype)" echo " - $((${#HDIMS[@]} * 2)) fixed files (split by dtype)" -echo " Total: $((${#HDIMS[@]} * 4)) files" +echo " - $((${#HDIMS[@]} * 2)) kvcache_paged files (split by dtype)" +echo " Total: $((${#HDIMS[@]} * 6)) files" diff --git a/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim128_kvcache_paged_bf16.cpp b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim128_kvcache_paged_bf16.cpp new file mode 100644 index 00000000..a22eabca --- /dev/null +++ b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim128_kvcache_paged_bf16.cpp @@ -0,0 +1,20 @@ +#include "fmha_fwd_impl.hpp" + +// Non-varlen + paged: IsVarLen=0, IsPaged=1, dtype=bf16 +// Used by mha_fwd_kvcache when block_table is provided. + +// Prefill paged +template void policy_dispatch_bf16< + prefill_policy_head128, + PipelineStages_Prefill, + 0, 1>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +// Decode paged (smaller K-tile to fit page boundaries) +template void policy_dispatch_bf16< + decode_paged_policy_head128, + PipelineStages_Decode, + 0, 1>( + sycl::queue& queue, + const fmha_fwd_args_t& args); diff --git a/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim128_kvcache_paged_fp16.cpp b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim128_kvcache_paged_fp16.cpp new file mode 100644 index 00000000..54050a99 --- /dev/null +++ b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim128_kvcache_paged_fp16.cpp @@ -0,0 +1,20 @@ +#include "fmha_fwd_impl.hpp" + +// Non-varlen + paged: IsVarLen=0, IsPaged=1, dtype=fp16 +// Used by mha_fwd_kvcache when block_table is provided. + +// Prefill paged +template void policy_dispatch_fp16< + prefill_policy_head128, + PipelineStages_Prefill, + 0, 1>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +// Decode paged (smaller K-tile to fit page boundaries) +template void policy_dispatch_fp16< + decode_paged_policy_head128, + PipelineStages_Decode, + 0, 1>( + sycl::queue& queue, + const fmha_fwd_args_t& args); diff --git a/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim160_kvcache_paged_bf16.cpp b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim160_kvcache_paged_bf16.cpp new file mode 100644 index 00000000..d9b4b30f --- /dev/null +++ b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim160_kvcache_paged_bf16.cpp @@ -0,0 +1,20 @@ +#include "fmha_fwd_impl.hpp" + +// Non-varlen + paged: IsVarLen=0, IsPaged=1, dtype=bf16 +// Used by mha_fwd_kvcache when block_table is provided. + +// Prefill paged +template void policy_dispatch_bf16< + prefill_policy_head160, + PipelineStages_Prefill, + 0, 1>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +// Decode paged (smaller K-tile to fit page boundaries) +template void policy_dispatch_bf16< + decode_paged_policy_head160, + PipelineStages_Decode, + 0, 1>( + sycl::queue& queue, + const fmha_fwd_args_t& args); diff --git a/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim160_kvcache_paged_fp16.cpp b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim160_kvcache_paged_fp16.cpp new file mode 100644 index 00000000..d68d6f95 --- /dev/null +++ b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim160_kvcache_paged_fp16.cpp @@ -0,0 +1,20 @@ +#include "fmha_fwd_impl.hpp" + +// Non-varlen + paged: IsVarLen=0, IsPaged=1, dtype=fp16 +// Used by mha_fwd_kvcache when block_table is provided. + +// Prefill paged +template void policy_dispatch_fp16< + prefill_policy_head160, + PipelineStages_Prefill, + 0, 1>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +// Decode paged (smaller K-tile to fit page boundaries) +template void policy_dispatch_fp16< + decode_paged_policy_head160, + PipelineStages_Decode, + 0, 1>( + sycl::queue& queue, + const fmha_fwd_args_t& args); diff --git a/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim192_kvcache_paged_bf16.cpp b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim192_kvcache_paged_bf16.cpp new file mode 100644 index 00000000..17a0cf3b --- /dev/null +++ b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim192_kvcache_paged_bf16.cpp @@ -0,0 +1,20 @@ +#include "fmha_fwd_impl.hpp" + +// Non-varlen + paged: IsVarLen=0, IsPaged=1, dtype=bf16 +// Used by mha_fwd_kvcache when block_table is provided. + +// Prefill paged +template void policy_dispatch_bf16< + prefill_policy_head192, + PipelineStages_Prefill, + 0, 1>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +// Decode paged (smaller K-tile to fit page boundaries) +template void policy_dispatch_bf16< + decode_paged_policy_head192, + PipelineStages_Decode, + 0, 1>( + sycl::queue& queue, + const fmha_fwd_args_t& args); diff --git a/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim192_kvcache_paged_fp16.cpp b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim192_kvcache_paged_fp16.cpp new file mode 100644 index 00000000..9c48d19c --- /dev/null +++ b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim192_kvcache_paged_fp16.cpp @@ -0,0 +1,20 @@ +#include "fmha_fwd_impl.hpp" + +// Non-varlen + paged: IsVarLen=0, IsPaged=1, dtype=fp16 +// Used by mha_fwd_kvcache when block_table is provided. + +// Prefill paged +template void policy_dispatch_fp16< + prefill_policy_head192, + PipelineStages_Prefill, + 0, 1>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +// Decode paged (smaller K-tile to fit page boundaries) +template void policy_dispatch_fp16< + decode_paged_policy_head192, + PipelineStages_Decode, + 0, 1>( + sycl::queue& queue, + const fmha_fwd_args_t& args); diff --git a/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim256_kvcache_paged_bf16.cpp b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim256_kvcache_paged_bf16.cpp new file mode 100644 index 00000000..fb753042 --- /dev/null +++ b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim256_kvcache_paged_bf16.cpp @@ -0,0 +1,20 @@ +#include "fmha_fwd_impl.hpp" + +// Non-varlen + paged: IsVarLen=0, IsPaged=1, dtype=bf16 +// Used by mha_fwd_kvcache when block_table is provided. + +// Prefill paged +template void policy_dispatch_bf16< + prefill_policy_head256, + PipelineStages_Prefill, + 0, 1>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +// Decode paged (smaller K-tile to fit page boundaries) +template void policy_dispatch_bf16< + decode_paged_policy_head256, + PipelineStages_Decode, + 0, 1>( + sycl::queue& queue, + const fmha_fwd_args_t& args); diff --git a/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim256_kvcache_paged_fp16.cpp b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim256_kvcache_paged_fp16.cpp new file mode 100644 index 00000000..435c1c62 --- /dev/null +++ b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim256_kvcache_paged_fp16.cpp @@ -0,0 +1,20 @@ +#include "fmha_fwd_impl.hpp" + +// Non-varlen + paged: IsVarLen=0, IsPaged=1, dtype=fp16 +// Used by mha_fwd_kvcache when block_table is provided. + +// Prefill paged +template void policy_dispatch_fp16< + prefill_policy_head256, + PipelineStages_Prefill, + 0, 1>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +// Decode paged (smaller K-tile to fit page boundaries) +template void policy_dispatch_fp16< + decode_paged_policy_head256, + PipelineStages_Decode, + 0, 1>( + sycl::queue& queue, + const fmha_fwd_args_t& args); diff --git a/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim32_kvcache_paged_bf16.cpp b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim32_kvcache_paged_bf16.cpp new file mode 100644 index 00000000..ef035b4b --- /dev/null +++ b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim32_kvcache_paged_bf16.cpp @@ -0,0 +1,20 @@ +#include "fmha_fwd_impl.hpp" + +// Non-varlen + paged: IsVarLen=0, IsPaged=1, dtype=bf16 +// Used by mha_fwd_kvcache when block_table is provided. + +// Prefill paged +template void policy_dispatch_bf16< + prefill_policy_head32, + PipelineStages_Prefill, + 0, 1>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +// Decode paged (smaller K-tile to fit page boundaries) +template void policy_dispatch_bf16< + decode_paged_policy_head32, + PipelineStages_Decode, + 0, 1>( + sycl::queue& queue, + const fmha_fwd_args_t& args); diff --git a/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim32_kvcache_paged_fp16.cpp b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim32_kvcache_paged_fp16.cpp new file mode 100644 index 00000000..2053fc80 --- /dev/null +++ b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim32_kvcache_paged_fp16.cpp @@ -0,0 +1,20 @@ +#include "fmha_fwd_impl.hpp" + +// Non-varlen + paged: IsVarLen=0, IsPaged=1, dtype=fp16 +// Used by mha_fwd_kvcache when block_table is provided. + +// Prefill paged +template void policy_dispatch_fp16< + prefill_policy_head32, + PipelineStages_Prefill, + 0, 1>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +// Decode paged (smaller K-tile to fit page boundaries) +template void policy_dispatch_fp16< + decode_paged_policy_head32, + PipelineStages_Decode, + 0, 1>( + sycl::queue& queue, + const fmha_fwd_args_t& args); diff --git a/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim512_kvcache_paged_bf16.cpp b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim512_kvcache_paged_bf16.cpp new file mode 100644 index 00000000..e8515025 --- /dev/null +++ b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim512_kvcache_paged_bf16.cpp @@ -0,0 +1,20 @@ +#include "fmha_fwd_impl.hpp" + +// Non-varlen + paged: IsVarLen=0, IsPaged=1, dtype=bf16 +// Used by mha_fwd_kvcache when block_table is provided. + +// Prefill paged +template void policy_dispatch_bf16< + prefill_policy_head512, + PipelineStages_Prefill, + 0, 1>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +// Decode paged (smaller K-tile to fit page boundaries) +template void policy_dispatch_bf16< + decode_paged_policy_head512, + PipelineStages_Decode, + 0, 1>( + sycl::queue& queue, + const fmha_fwd_args_t& args); diff --git a/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim512_kvcache_paged_fp16.cpp b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim512_kvcache_paged_fp16.cpp new file mode 100644 index 00000000..95c3cbbf --- /dev/null +++ b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim512_kvcache_paged_fp16.cpp @@ -0,0 +1,20 @@ +#include "fmha_fwd_impl.hpp" + +// Non-varlen + paged: IsVarLen=0, IsPaged=1, dtype=fp16 +// Used by mha_fwd_kvcache when block_table is provided. + +// Prefill paged +template void policy_dispatch_fp16< + prefill_policy_head512, + PipelineStages_Prefill, + 0, 1>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +// Decode paged (smaller K-tile to fit page boundaries) +template void policy_dispatch_fp16< + decode_paged_policy_head512, + PipelineStages_Decode, + 0, 1>( + sycl::queue& queue, + const fmha_fwd_args_t& args); diff --git a/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim64_kvcache_paged_bf16.cpp b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim64_kvcache_paged_bf16.cpp new file mode 100644 index 00000000..c793d538 --- /dev/null +++ b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim64_kvcache_paged_bf16.cpp @@ -0,0 +1,20 @@ +#include "fmha_fwd_impl.hpp" + +// Non-varlen + paged: IsVarLen=0, IsPaged=1, dtype=bf16 +// Used by mha_fwd_kvcache when block_table is provided. + +// Prefill paged +template void policy_dispatch_bf16< + prefill_policy_head64, + PipelineStages_Prefill, + 0, 1>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +// Decode paged (smaller K-tile to fit page boundaries) +template void policy_dispatch_bf16< + decode_paged_policy_head64, + PipelineStages_Decode, + 0, 1>( + sycl::queue& queue, + const fmha_fwd_args_t& args); diff --git a/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim64_kvcache_paged_fp16.cpp b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim64_kvcache_paged_fp16.cpp new file mode 100644 index 00000000..e6924f25 --- /dev/null +++ b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim64_kvcache_paged_fp16.cpp @@ -0,0 +1,20 @@ +#include "fmha_fwd_impl.hpp" + +// Non-varlen + paged: IsVarLen=0, IsPaged=1, dtype=fp16 +// Used by mha_fwd_kvcache when block_table is provided. + +// Prefill paged +template void policy_dispatch_fp16< + prefill_policy_head64, + PipelineStages_Prefill, + 0, 1>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +// Decode paged (smaller K-tile to fit page boundaries) +template void policy_dispatch_fp16< + decode_paged_policy_head64, + PipelineStages_Decode, + 0, 1>( + sycl::queue& queue, + const fmha_fwd_args_t& args); diff --git a/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim96_kvcache_paged_bf16.cpp b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim96_kvcache_paged_bf16.cpp new file mode 100644 index 00000000..aab66a39 --- /dev/null +++ b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim96_kvcache_paged_bf16.cpp @@ -0,0 +1,20 @@ +#include "fmha_fwd_impl.hpp" + +// Non-varlen + paged: IsVarLen=0, IsPaged=1, dtype=bf16 +// Used by mha_fwd_kvcache when block_table is provided. + +// Prefill paged +template void policy_dispatch_bf16< + prefill_policy_head96, + PipelineStages_Prefill, + 0, 1>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +// Decode paged (smaller K-tile to fit page boundaries) +template void policy_dispatch_bf16< + decode_paged_policy_head96, + PipelineStages_Decode, + 0, 1>( + sycl::queue& queue, + const fmha_fwd_args_t& args); diff --git a/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim96_kvcache_paged_fp16.cpp b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim96_kvcache_paged_fp16.cpp new file mode 100644 index 00000000..049ac07c --- /dev/null +++ b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim96_kvcache_paged_fp16.cpp @@ -0,0 +1,20 @@ +#include "fmha_fwd_impl.hpp" + +// Non-varlen + paged: IsVarLen=0, IsPaged=1, dtype=fp16 +// Used by mha_fwd_kvcache when block_table is provided. + +// Prefill paged +template void policy_dispatch_fp16< + prefill_policy_head96, + PipelineStages_Prefill, + 0, 1>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +// Decode paged (smaller K-tile to fit page boundaries) +template void policy_dispatch_fp16< + decode_paged_policy_head96, + PipelineStages_Decode, + 0, 1>( + sycl::queue& queue, + const fmha_fwd_args_t& args); diff --git a/flash-attn2/flash_attn_xpu/src/fmha_fwd.cpp b/flash-attn2/flash_attn_xpu/src/fmha_fwd.cpp index f80f4590..bb9e4600 100644 --- a/flash-attn2/flash_attn_xpu/src/fmha_fwd.cpp +++ b/flash-attn2/flash_attn_xpu/src/fmha_fwd.cpp @@ -112,6 +112,37 @@ void dispatch_fwd_prefill_by_head(sycl::queue& queue, CutlassType cuType, else throw std::runtime_error("Unsupported head_size: " + std::to_string(head_size) + ". Only <= 256 or exactly 512 is supported"); } +/// Dispatch forward kernel by head_size for the kvcache prefill paged path +/// (non-varlen, seqlen_q > 1, IsPaged=1). +void dispatch_fwd_kvcache_prefill_paged_by_head(sycl::queue& queue, CutlassType cuType, + const fmha_fwd_args_t& args, int head_size) { + if (head_size <= 32) policy_dispatch(queue, cuType, args); + else if (head_size <= 64) policy_dispatch(queue, cuType, args); + else if (head_size <= 96) policy_dispatch(queue, cuType, args); + else if (head_size <= 128) policy_dispatch(queue, cuType, args); + else if (head_size <= 160) policy_dispatch(queue, cuType, args); + else if (head_size <= 192) policy_dispatch(queue, cuType, args); + else if (head_size <= 256) policy_dispatch(queue, cuType, args); + else if (head_size == 512) policy_dispatch(queue, cuType, args); + else throw std::runtime_error("Unsupported head_size: " + std::to_string(head_size) + ". Only <= 256 or exactly 512 is supported"); +} + +/// Dispatch forward kernel by head_size for the kvcache decode paged path +/// (non-varlen, seqlen_q == 1, IsPaged=1). Uses decode_paged_policy with a +/// smaller K-tile so block_size=64 multiples don't overshoot a page. +void dispatch_fwd_kvcache_decode_paged_by_head(sycl::queue& queue, CutlassType cuType, + const fmha_fwd_args_t& args, int head_size) { + if (head_size <= 32) policy_dispatch(queue, cuType, args); + else if (head_size <= 64) policy_dispatch(queue, cuType, args); + else if (head_size <= 96) policy_dispatch(queue, cuType, args); + else if (head_size <= 128) policy_dispatch(queue, cuType, args); + else if (head_size <= 160) policy_dispatch(queue, cuType, args); + else if (head_size <= 192) policy_dispatch(queue, cuType, args); + else if (head_size <= 256) policy_dispatch(queue, cuType, args); + else if (head_size == 512) policy_dispatch(queue, cuType, args); + else throw std::runtime_error("Unsupported head_size: " + std::to_string(head_size) + ". Only <= 256 or exactly 512 is supported"); +} + /// Clamp window sizes for local attention and fold causal into local when both are set. void normalize_window_params(int& window_size_left, int& window_size_right, bool& is_causal, bool is_local, int max_seqlen_k) { @@ -319,21 +350,46 @@ void cutlass_fmha_fwd_kvcache_impl( const std::optional& cache_leftpad, const std::optional& knew, const std::optional& vnew, + const std::optional& block_table, + int max_seqlen_k_paged, float sm_scale, int window_size_left, int window_size_right, bool is_causal, bool is_local) { + const bool is_paged = block_table.has_value() && block_table->defined(); + const int batch_size = query.size(0); const int max_seqlen_q = query.size(1); const int num_heads_q = query.size(2); const int head_size = query.size(3); - const int max_seqlen_k = kcache.size(1); - const int num_heads_kv = kcache.size(2); + int max_seqlen_k; + int num_heads_kv; + int num_blocks = 0; + int block_size = 0; + int max_blocks_per_seq = 0; + if (is_paged) { + // Paged layout: kcache is (num_blocks, page_block_size, num_heads_kv, head_size) + num_blocks = kcache.size(0); + block_size = kcache.size(1); + num_heads_kv = kcache.size(2); + max_blocks_per_seq = block_table->size(1); + max_seqlen_k = max_seqlen_k_paged; + TORCH_CHECK(num_blocks * block_size >= max_seqlen_k, + "Paged KV pool too small for max_seqlen_k: ", + num_blocks * block_size, " < ", max_seqlen_k); + TORCH_CHECK(!cache_batch_idx.has_value(), + "Paged KVcache does not support cache_batch_idx"); + } else { + max_seqlen_k = kcache.size(1); + num_heads_kv = kcache.size(2); + } const int total_seqlen_q = batch_size * max_seqlen_q; - const int total_seqlen_k = batch_size * max_seqlen_k; + const int total_seqlen_k = is_paged + ? num_blocks * block_size + : batch_size * max_seqlen_k; normalize_window_params(window_size_left, window_size_right, is_causal, is_local, max_seqlen_k); @@ -344,7 +400,7 @@ void cutlass_fmha_fwd_kvcache_impl( vcache.data_ptr(), out.data_ptr(), softmax_lse.data_ptr(), - nullptr, // block_table + is_paged ? block_table->data_ptr() : nullptr, nullptr, // cu_seqlens_q nullptr, // cu_seqlens_k max_seqlen_q, @@ -356,12 +412,12 @@ void cutlass_fmha_fwd_kvcache_impl( num_heads_q, num_heads_kv, head_size, - 0, // max_blocks_per_seq - 0, // block_size + max_blocks_per_seq, + block_size, window_size_left, window_size_right, false, // is_varlen - false, // is_paged + is_paged, is_causal, is_local, 0.0f, 0, 0, nullptr, nullptr, 0, 0, // dropout & s_dmask defaults @@ -385,9 +441,28 @@ void cutlass_fmha_fwd_kvcache_impl( const CutlassType cuType = aten_to_Cutlass_dtype(query); const int h = args.head_size; - if (max_seqlen_q == 1) { - dispatch_fwd_decode_by_head(queue, cuType, args, h); + if (is_paged) { + // Paged dispatch requires the K-tile to evenly divide block_size, + // because each tile is loaded from a single page. + const int k_tile_n = paged_k_tile_size_n(h, max_seqlen_q); + TORCH_CHECK(k_tile_n > 0, + "Unsupported head_size for paged FA2 kvcache: ", h); + TORCH_CHECK(is_supported_paged_block_size(block_size), + "Unsupported paged KV block_size=", block_size, + ". Supported values are positive multiples of 64."); + TORCH_CHECK(block_size % k_tile_n == 0, + "Paged KV block_size must be a multiple of the kernel K tile. " + "Got block_size=", block_size, ", K tile=", k_tile_n); + if (max_seqlen_q == 1) { + dispatch_fwd_kvcache_decode_paged_by_head(queue, cuType, args, h); + } else { + dispatch_fwd_kvcache_prefill_paged_by_head(queue, cuType, args, h); + } } else { - dispatch_fwd_prefill_by_head(queue, cuType, args, h); + if (max_seqlen_q == 1) { + dispatch_fwd_decode_by_head(queue, cuType, args, h); + } else { + dispatch_fwd_prefill_by_head(queue, cuType, args, h); + } } } diff --git a/flash-attn2/flash_attn_xpu/src/fmha_fwd.hpp b/flash-attn2/flash_attn_xpu/src/fmha_fwd.hpp index 9cc9fd0e..2179e7a5 100644 --- a/flash-attn2/flash_attn_xpu/src/fmha_fwd.hpp +++ b/flash-attn2/flash_attn_xpu/src/fmha_fwd.hpp @@ -119,6 +119,23 @@ EXTERN_DISPATCH_FIX(256) EXTERN_DISPATCH_FIX(512) #undef EXTERN_DISPATCH_FIX +// KVCache + paged extern declarations (IsVarLen=0, IsPaged=1) +#define EXTERN_DISPATCH_KVCACHE_PAGED(HDIM) \ + extern template void policy_dispatch_fp16(sycl::queue&, const fmha_fwd_args_t&); \ + extern template void policy_dispatch_fp16(sycl::queue&, const fmha_fwd_args_t&); \ + extern template void policy_dispatch_bf16(sycl::queue&, const fmha_fwd_args_t&); \ + extern template void policy_dispatch_bf16(sycl::queue&, const fmha_fwd_args_t&); + +EXTERN_DISPATCH_KVCACHE_PAGED(32) +EXTERN_DISPATCH_KVCACHE_PAGED(64) +EXTERN_DISPATCH_KVCACHE_PAGED(96) +EXTERN_DISPATCH_KVCACHE_PAGED(128) +EXTERN_DISPATCH_KVCACHE_PAGED(160) +EXTERN_DISPATCH_KVCACHE_PAGED(192) +EXTERN_DISPATCH_KVCACHE_PAGED(256) +EXTERN_DISPATCH_KVCACHE_PAGED(512) +#undef EXTERN_DISPATCH_KVCACHE_PAGED + void cutlass_fmha_fwd_varlen_impl( sycl::queue& queue, const at::Tensor& query, @@ -171,6 +188,8 @@ void cutlass_fmha_fwd_kvcache_impl( const std::optional& cache_leftpad, const std::optional& knew, const std::optional& vnew, + const std::optional& block_table, + int max_seqlen_k_paged, float sm_scale, int window_size_left, int window_size_right, diff --git a/flash-attn2/flash_attn_xpu/src/kernel/fmha_fwd_kernel_xe2.hpp b/flash-attn2/flash_attn_xpu/src/kernel/fmha_fwd_kernel_xe2.hpp index 9a8f6018..5928a4a1 100644 --- a/flash-attn2/flash_attn_xpu/src/kernel/fmha_fwd_kernel_xe2.hpp +++ b/flash-attn2/flash_attn_xpu/src/kernel/fmha_fwd_kernel_xe2.hpp @@ -228,8 +228,11 @@ class XeFMHAFwdKernelXe2 { // KV Cache: override seq_len_kv with per-batch effective length int effective_seq_kv = seq_len_kv; int leftpad_k = 0; + // bidx maps the logical request to a slot in the physical KV cache. + // - paged path: cache_batch_idx is forbidden, K/V are flat blocks + // - non-paged path: bidx selects the per-batch KV slice + int bidx = (!PagedKV && p.cache_batch_idx) ? p.cache_batch_idx[idx_b] : idx_b; if (p.cache_seqlens) { - int bidx = p.cache_batch_idx ? p.cache_batch_idx[idx_b] : idx_b; int orig_cache_seqlens = p.cache_seqlens[bidx]; if (p.cache_leftpad) { leftpad_k = p.cache_leftpad[bidx]; @@ -238,24 +241,62 @@ class XeFMHAFwdKernelXe2 { // Fused cache update: copy knew/vnew into kcache/vcache if (p.Knew != nullptr && p.seqlen_knew > 0) { constexpr int num_threads = SGPerWG::value * cute::intel::sg_size; - auto* k_dst = const_cast(p.K) - + bidx * p.k_batch_stride + head * p.k_head_stride - + static_cast(orig_cache_seqlens) * p.k_row_stride; auto* k_src = p.Knew + idx_b * p.knew_batch_stride + head * p.knew_head_stride; - for (int si = 0; si < p.seqlen_knew; si++) { - for (int d = thr_id; d < s.head_size_qk; d += num_threads) { - k_dst[si * p.k_row_stride + d] = k_src[si * p.knew_row_stride + d]; - } - } - auto* v_dst = const_cast(p.V) - + bidx * p.v_batch_stride + head * p.v_head_stride - + static_cast(orig_cache_seqlens) * p.v_row_stride; auto* v_src = p.Vnew + idx_b * p.vnew_batch_stride + head * p.vnew_head_stride; - for (int si = 0; si < p.seqlen_knew; si++) { - for (int d = thr_id; d < s.head_size_vo; d += num_threads) { - v_dst[si * p.v_row_stride + d] = v_src[si * p.vnew_row_stride + d]; + if constexpr (PagedKV) { + // Paged scatter: per-token compute (block, page_offset) from + // block_table, then write to the corresponding page slot. + // Each "block" in the paged K/V tensor spans `page_size` rows, + // so the per-block byte stride is `page_size * row_stride`, + // not `k_batch_stride` (which is sized for the *whole* logical + // KV layout, not per-page). + const int page_size = params.mainloop.paged.page_size; + const int max_pages_per_seq = + params.mainloop.paged.max_pages_per_seq; + const int* page_table = params.mainloop.paged.ptr_page_table + + idx_b * max_pages_per_seq; + const int64_t k_block_stride = + static_cast(page_size) * p.k_row_stride; + const int64_t v_block_stride = + static_cast(page_size) * p.v_row_stride; + for (int si = 0; si < p.seqlen_knew; si++) { + int global_pos = orig_cache_seqlens + si; + int page_idx = global_pos / page_size; + int page_off = global_pos % page_size; + int block = page_table[page_idx]; + auto* k_dst = const_cast(p.K) + + static_cast(block) * k_block_stride + + head * p.k_head_stride + + static_cast(page_off) * p.k_row_stride; + auto* v_dst = const_cast(p.V) + + static_cast(block) * v_block_stride + + head * p.v_head_stride + + static_cast(page_off) * p.v_row_stride; + for (int d = thr_id; d < s.head_size_qk; d += num_threads) { + k_dst[d] = k_src[si * p.knew_row_stride + d]; + } + for (int d = thr_id; d < s.head_size_vo; d += num_threads) { + v_dst[d] = v_src[si * p.vnew_row_stride + d]; + } + } + } else { + auto* k_dst = const_cast(p.K) + + bidx * p.k_batch_stride + head * p.k_head_stride + + static_cast(orig_cache_seqlens) * p.k_row_stride; + auto* v_dst = const_cast(p.V) + + bidx * p.v_batch_stride + head * p.v_head_stride + + static_cast(orig_cache_seqlens) * p.v_row_stride; + for (int si = 0; si < p.seqlen_knew; si++) { + for (int d = thr_id; d < s.head_size_qk; d += num_threads) { + k_dst[si * p.k_row_stride + d] = k_src[si * p.knew_row_stride + d]; + } + } + for (int si = 0; si < p.seqlen_knew; si++) { + for (int d = thr_id; d < s.head_size_vo; d += num_threads) { + v_dst[si * p.v_row_stride + d] = v_src[si * p.vnew_row_stride + d]; + } } } sycl::group_barrier(get_work_group<3>()); @@ -320,7 +361,10 @@ class XeFMHAFwdKernelXe2 { offset_o = s.num_heads_q * s.head_size_vo * qo_cumulative[idx_b]; } - auto batch_dim = is_var_len ? 1 : s.batch; + auto batch_dim_q = is_var_len ? 1 : s.batch; + // Paged KV is laid out as (num_blocks * page_size, head, num_heads_kv) + // with no batch dimension; treat it like the varlen K/V layout. + auto batch_dim_kv = (is_var_len || PagedKV) ? 1 : s.batch; int total_seqlen_kv; if constexpr (PagedKV) { total_seqlen_kv = params.mainloop.paged.total_seqlen_kv; @@ -333,13 +377,13 @@ class XeFMHAFwdKernelXe2 { // extending past the per-batch allocation. int kv_surface_len = total_seqlen_kv - leftpad_k; auto shape_Q = - make_shape(seq_len_qo, s.head_size_qk, s.num_heads_q, batch_dim); + make_shape(seq_len_qo, s.head_size_qk, s.num_heads_q, batch_dim_q); auto shape_K = make_shape( - kv_surface_len, s.head_size_qk, s.num_heads_kv, batch_dim); + kv_surface_len, s.head_size_qk, s.num_heads_kv, batch_dim_kv); auto shape_V = make_shape( - s.head_size_vo, kv_surface_len, s.num_heads_kv, batch_dim); + s.head_size_vo, kv_surface_len, s.num_heads_kv, batch_dim_kv); auto shape_O = - make_shape(seq_len_qo, s.head_size_vo, s.num_heads_q, batch_dim); + make_shape(seq_len_qo, s.head_size_vo, s.num_heads_q, batch_dim_q); auto stride_q = cutlass::make_stride( static_cast(p.q_row_stride), Int<1>{}, @@ -393,12 +437,16 @@ class XeFMHAFwdKernelXe2 { int rows_of_maxima = get<0>(shape_div(TileShapeQK{}, shape(SubgroupLayoutQK{}))); - int l_coord = is_var_len ? 0 : idx_b; + // For non-paged KV, reuse cache_batch_idx remap (bidx) so that the + // KV slice matches the per-request seqlen. Q/O always use idx_b. + int l_coord_q = is_var_len ? 0 : idx_b; + int l_coord_kv = is_var_len ? 0 + : (PagedKV ? 0 : bidx); CollectiveMainloop mainloop(params.mainloop, shared_storage.mainloop); mainloop( - Q(_, _, head_q, l_coord), - K(_, _, head, l_coord), - V(_, _, head, l_coord), + Q(_, _, head_q, l_coord_q), + K(_, _, head, l_coord_kv), + V(_, _, head, l_coord_kv), tArA, tA_max, tA_sum, @@ -433,7 +481,7 @@ class XeFMHAFwdKernelXe2 { tile_row_idx, rows_of_maxima); epilogue( - O(_, _, head_q, l_coord), + O(_, _, head_q, l_coord_q), tArA, tA_max, tA_sum, From 7f7267dd50028546c414fb2f90797f033a9ca72b Mon Sep 17 00:00:00 2001 From: YangKai0616 Date: Sat, 9 May 2026 01:30:53 +0000 Subject: [PATCH 5/8] Add fused rotary support for XPU kvcache --- flash-attn2/flash_attn_xpu/flash_api.cpp | 64 ++++++------ .../src/collective/fmha_fwd_common.hpp | 42 ++++++++ .../src/collective/fmha_fwd_mainloop_xe2.hpp | 64 +++++++++++- .../src/create_instantiation_files.sh | 29 ++++++ .../flash_fwd_hdim128_kvcache_paged_bf16.cpp | 29 ++++++ .../flash_fwd_hdim128_kvcache_paged_fp16.cpp | 29 ++++++ .../flash_fwd_hdim160_kvcache_paged_bf16.cpp | 29 ++++++ .../flash_fwd_hdim160_kvcache_paged_fp16.cpp | 29 ++++++ .../flash_fwd_hdim192_kvcache_paged_bf16.cpp | 29 ++++++ .../flash_fwd_hdim192_kvcache_paged_fp16.cpp | 29 ++++++ .../flash_fwd_hdim256_kvcache_paged_bf16.cpp | 29 ++++++ .../flash_fwd_hdim256_kvcache_paged_fp16.cpp | 29 ++++++ .../flash_fwd_hdim32_kvcache_paged_bf16.cpp | 29 ++++++ .../flash_fwd_hdim32_kvcache_paged_fp16.cpp | 29 ++++++ .../flash_fwd_hdim512_kvcache_paged_bf16.cpp | 29 ++++++ .../flash_fwd_hdim512_kvcache_paged_fp16.cpp | 29 ++++++ .../flash_fwd_hdim64_kvcache_paged_bf16.cpp | 29 ++++++ .../flash_fwd_hdim64_kvcache_paged_fp16.cpp | 29 ++++++ .../flash_fwd_hdim96_kvcache_paged_bf16.cpp | 29 ++++++ .../flash_fwd_hdim96_kvcache_paged_fp16.cpp | 29 ++++++ flash-attn2/flash_attn_xpu/src/fmha_fwd.cpp | 99 ++++++++++++++----- flash-attn2/flash_attn_xpu/src/fmha_fwd.hpp | 17 +++- .../flash_attn_xpu/src/fmha_fwd_impl.hpp | 25 +++-- .../flash_attn_xpu/src/fmha_fwd_types.hpp | 6 ++ .../src/kernel/fmha_fwd_kernel_xe2.hpp | 47 ++++++++- 25 files changed, 783 insertions(+), 74 deletions(-) diff --git a/flash-attn2/flash_attn_xpu/flash_api.cpp b/flash-attn2/flash_attn_xpu/flash_api.cpp index c84e2e61..1ef3cfd8 100644 --- a/flash-attn2/flash_attn_xpu/flash_api.cpp +++ b/flash-attn2/flash_attn_xpu/flash_api.cpp @@ -592,29 +592,19 @@ mha_fwd_kvcache( CHECK_DEVICE(seqlens_k); } - // Handle rotary embedding (pre-process in-place before kernel) - if (rotary_cos_.has_value()) { + at::Tensor rotary_cos, rotary_sin; + int rotary_dim = 0; + const bool has_rotary = rotary_cos_.has_value(); + if (has_rotary) { TORCH_CHECK(k_.has_value(), "If rotary cos/sin are provided, new key/value must also be provided"); - auto rotary_cos = rotary_cos_.value(); - auto rotary_sin = rotary_sin_.value(); + TORCH_CHECK(rotary_sin_.has_value(), "If rotary cos is provided, rotary sin must also be provided"); + rotary_cos = ensure_contiguous(rotary_cos_.value()); + rotary_sin = ensure_contiguous(rotary_sin_.value()); CHECK_DEVICE(rotary_cos); CHECK_DEVICE(rotary_sin); - int rotary_dim = rotary_cos.size(1) * 2; + rotary_dim = rotary_cos.size(1) * 2; TORCH_CHECK(rotary_dim <= head_size_og, "rotary_dim must be <= headdim"); TORCH_CHECK(rotary_dim % 16 == 0, "Only rotary dimensions divisible by 16 are currently supported"); TORCH_CHECK(rotary_cos.scalar_type() == q_dtype && rotary_sin.scalar_type() == q_dtype); - - std::optional seqlen_offsets_opt; - if (seqlens_k_.has_value()) { seqlen_offsets_opt = seqlens_k; } - - bool is_local = (window_size_left >= 0); - if (is_causal || is_local) { - apply_rotary_emb_inplace(q_padded, rotary_cos, rotary_sin, seqlen_offsets_opt, is_rotary_interleaved); - } else { - auto q_shape = q_padded.sizes(); - auto q_reshaped = q_padded.view({q_shape[0], 1, q_shape[1] * q_shape[2], q_shape[3]}); - apply_rotary_emb_inplace(q_reshaped, rotary_cos, rotary_sin, seqlen_offsets_opt, is_rotary_interleaved); - } - apply_rotary_emb_inplace(k_padded, rotary_cos, rotary_sin, seqlen_offsets_opt, is_rotary_interleaved); } at::Tensor cache_batch_idx; @@ -638,21 +628,30 @@ mha_fwd_kvcache( // - Always prefer kernel-fused scatter (passes knew/vnew to the kernel, // which writes them in-place during the prologue). This avoids any // host sync and works for both contiguous and paged caches. - // - Fall back to API-layer scatter only when fusion is impossible: - // * needs_padding: the cache pad is a separate buffer, so the - // in-kernel writer would write to the padded copy, not the user - // tensor; do the scatter on the user tensor and re-pad. - // * rotary_cos: the rotary application happened on the padded - // buffer; we need to slice off the padding before scattering to - // the user cache. (Kernel-fused scatter copies the padded buffer - // instead, which is wrong.) + // - Fall back to API-layer scatter only when padding is needed: the + // padded cache is a separate buffer, so the in-kernel writer would + // not update the user tensor. bool fuse_knew = k_.has_value() && seqlen_new > 0 - && !needs_padding && !rotary_cos_.has_value(); + && !needs_padding; + if (has_rotary && !fuse_knew) { + std::optional seqlen_offsets_opt; + if (seqlens_k_.has_value()) { seqlen_offsets_opt = seqlens_k; } + + bool is_local = (window_size_left >= 0); + if (is_causal || is_local) { + apply_rotary_emb_inplace(q_padded, rotary_cos, rotary_sin, seqlen_offsets_opt, is_rotary_interleaved); + } else { + auto q_shape = q_padded.sizes(); + auto q_reshaped = q_padded.view({q_shape[0], 1, q_shape[1] * q_shape[2], q_shape[3]}); + apply_rotary_emb_inplace(q_reshaped, rotary_cos, rotary_sin, seqlen_offsets_opt, is_rotary_interleaved); + } + apply_rotary_emb_inplace(k_padded, rotary_cos, rotary_sin, seqlen_offsets_opt, is_rotary_interleaved); + } if (k_.has_value() && seqlen_new > 0 && !fuse_knew) { auto seqlens_cpu = seqlens_k.to(torch::kCPU); auto seqlens_accessor = seqlens_cpu.accessor(); - at::Tensor k_for_cache = rotary_cos_.has_value() + at::Tensor k_for_cache = has_rotary ? k_padded.index({torch::indexing::Slice(), torch::indexing::Slice(), torch::indexing::Slice(), torch::indexing::Slice(0, head_size_og)}).contiguous() : ensure_contiguous(k_.value()); @@ -722,13 +721,20 @@ mha_fwd_kvcache( block_table_opt = block_table; } + std::optional rotary_cos_opt, rotary_sin_opt; + if (fuse_knew && has_rotary) { + rotary_cos_opt = rotary_cos; + rotary_sin_opt = rotary_sin; + } + cutlass_fmha_fwd_kvcache_impl( queue, q_padded, kcache_padded, vcache_padded, out, softmax_lse, seqlens_k, cache_batch_idx_opt, leftpad_k_opt, knew_opt, vnew_opt, - block_table_opt, seqlen_k, + block_table_opt, rotary_cos_opt, rotary_sin_opt, + fuse_knew ? rotary_dim : 0, is_rotary_interleaved, seqlen_k, softmax_scale, window_size_left, window_size_right, is_causal, is_local); diff --git a/flash-attn2/flash_attn_xpu/src/collective/fmha_fwd_common.hpp b/flash-attn2/flash_attn_xpu/src/collective/fmha_fwd_common.hpp index 7e8c0ea5..b992343a 100644 --- a/flash-attn2/flash_attn_xpu/src/collective/fmha_fwd_common.hpp +++ b/flash-attn2/flash_attn_xpu/src/collective/fmha_fwd_common.hpp @@ -22,6 +22,48 @@ namespace cutlass::fmha::collective { using namespace cute; +template +CUTLASS_DEVICE Element apply_rotary_scalar( + Element x, + Element x_pair, + const RotaryElement* cos, + const RotaryElement* sin, + int position, + int dim, + int rotary_dim, + bool interleaved) { + if (rotary_dim == 0 || dim >= rotary_dim) { + return x; + } + + int half_rotary = rotary_dim / 2; + int cos_sin_idx = interleaved ? dim / 2 + : (dim < half_rotary ? dim : dim - half_rotary); + bool is_second = interleaved ? (dim % 2) : (dim >= half_rotary); + + float x_f = static_cast(x); + float x_pair_f = static_cast(x_pair); + float c = static_cast(cos[position * half_rotary + cos_sin_idx]); + float s = static_cast(sin[position * half_rotary + cos_sin_idx]); + float rotated = is_second ? x_pair_f * s + x_f * c + : x_f * c - x_pair_f * s; + return static_cast(rotated); +} + +CUTLASS_DEVICE int rotary_pair_dim( + int dim, + int rotary_dim, + bool interleaved) { + if (dim >= rotary_dim) { + return dim; + } + if (interleaved) { + return dim ^ 1; + } + int half_rotary = rotary_dim / 2; + return dim < half_rotary ? dim + half_rotary : dim - half_rotary; +} + ///////////////////////////////////////////////////////////////////////////////////////////////// // // FMHAFwdMainloopTraits: common type aliases derived from TiledMMA / VTiles. diff --git a/flash-attn2/flash_attn_xpu/src/collective/fmha_fwd_mainloop_xe2.hpp b/flash-attn2/flash_attn_xpu/src/collective/fmha_fwd_mainloop_xe2.hpp index 75c4e359..8f7a5b1e 100644 --- a/flash-attn2/flash_attn_xpu/src/collective/fmha_fwd_mainloop_xe2.hpp +++ b/flash-attn2/flash_attn_xpu/src/collective/fmha_fwd_mainloop_xe2.hpp @@ -47,7 +47,8 @@ template < class TensorV_, class TiledCopyQ_ = void, class TiledCopyK_ = void, - class TiledCopyV_ = void> + class TiledCopyV_ = void, + bool HasRotary_ = false> struct FMHAFwdMainloopXe2 { static_assert( cutlass::detail::dependent_false, @@ -70,7 +71,8 @@ template < class TensorV_, class TiledCopyQ_, class TiledCopyK_, - class TiledCopyV_> + class TiledCopyV_, + bool HasRotary_> struct FMHAFwdMainloopXe2< Xe2, CausalMask_, @@ -85,7 +87,8 @@ struct FMHAFwdMainloopXe2< TensorV_, TiledCopyQ_, TiledCopyK_, - TiledCopyV_> { + TiledCopyV_, + HasRotary_> { // Pull in common type aliases from the shared traits. using Traits = FMHAFwdMainloopTraits< @@ -123,6 +126,7 @@ struct FMHAFwdMainloopXe2< static constexpr bool LocalMask = LocalMask_; static constexpr bool HasDropout = HasDropout_; static constexpr bool PagedKV = PagedKV_; + static constexpr bool HasRotary = HasRotary_; // User-facing arguments struct Arguments { @@ -142,6 +146,10 @@ struct FMHAFwdMainloopXe2< int page_size = 0; int max_pages_per_seq = 0; int total_seqlen_kv = 0; + const typename TensorQ::element_type* rotary_cos = nullptr; + const typename TensorQ::element_type* rotary_sin = nullptr; + int rotary_dim = 0; + bool is_rotary_interleaved = true; }; struct LocalMaskFields { @@ -165,6 +173,14 @@ struct FMHAFwdMainloopXe2< }; struct EmptyPaged {}; + struct RotaryFields { + const typename TensorQ::element_type* rotary_cos = nullptr; + const typename TensorQ::element_type* rotary_sin = nullptr; + int rotary_dim = 0; + bool is_rotary_interleaved = true; + }; + struct EmptyRotary {}; + // Kernel-facing parameters struct Params { ElementS scale; @@ -174,6 +190,8 @@ struct FMHAFwdMainloopXe2< dropout_fields; [[no_unique_address]] conditional_t paged; + [[no_unique_address]] conditional_t + rotary; }; // SLM data @@ -209,6 +227,10 @@ struct FMHAFwdMainloopXe2< p.paged = {args.ptr_page_table, args.page_size, args.max_pages_per_seq, args.total_seqlen_kv}; } + if constexpr (HasRotary) { + p.rotary = {args.rotary_cos, args.rotary_sin, + args.rotary_dim, args.is_rotary_interleaved}; + } return p; } @@ -236,7 +258,9 @@ struct FMHAFwdMainloopXe2< int& tile_row_idx, const int& rows_of_maxima, int head_q, - int num_heads) { + int num_heads, + int q_offset_sg, + int rotary_base) { using namespace sycl::ext::oneapi::this_work_item; auto tile_shape_v = @@ -387,6 +411,38 @@ struct FMHAFwdMainloopXe2< CUTLASS_PRAGMA_UNROLL for (int D = 0; D < size<4>(tKgK); D++) { copy(copy_q, tQgQ(_, _, _, D), tQrQ); + if constexpr (HasRotary) { + if (params.rotary.rotary_dim > 0 && + params.rotary.rotary_cos != nullptr && + params.rotary.rotary_sin != nullptr) { + auto tQrQ_coords = tQrQ.tv_layout(); + int lane_id = static_cast(get_sub_group().get_local_linear_id()); + int q_tile_base = get<0>(blk_qv) * get<0>(TileShapeQK{}) + q_offset_sg; + int dim_tile_base = D * get<2>(TileShapeQK{}); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < tQrQ.size(); ++i) { + auto value_coord = idx2crd( + i, make_shape( + get<1>(shape(tQrQ_coords)), + get<2>(shape(tQrQ_coords)))); + auto coord = tQrQ_coords( + make_coord(lane_id, get<0>(value_coord), get<1>(value_coord))); + int row = q_tile_base + get<0>(coord); + int dim = dim_tile_base + get<1>(coord); + if (row < seq_len_qo && dim < params.rotary.rotary_dim) { + int pair_dim = rotary_pair_dim( + dim, params.rotary.rotary_dim, + params.rotary.is_rotary_interleaved); + int position = rotary_base + ((CausalMask || LocalMask) ? row : 0); + tQrQ(i) = apply_rotary_scalar( + tQrQ(i), Q_2D(row, pair_dim), params.rotary.rotary_cos, + params.rotary.rotary_sin, position, dim, + params.rotary.rotary_dim, + params.rotary.is_rotary_interleaved); + } + } + } + } copy(copy_k, tKgK_cache(_, _, _, D), tKrK); reorder(tQrQ, tSrQ); reorder(tKrK, tSrK); diff --git a/flash-attn2/flash_attn_xpu/src/create_instantiation_files.sh b/flash-attn2/flash_attn_xpu/src/create_instantiation_files.sh index c4b76554..aebd3c85 100755 --- a/flash-attn2/flash_attn_xpu/src/create_instantiation_files.sh +++ b/flash-attn2/flash_attn_xpu/src/create_instantiation_files.sh @@ -113,6 +113,35 @@ template void policy_dispatch_${dtype}< 0, 1>( sycl::queue& queue, const fmha_fwd_args_t& args); + +// Rotary kvcache variants keep rotary code out of non-rotary kernels. +template void policy_dispatch_${dtype}< + prefill_policy_head${hdim}, + PipelineStages_Prefill, + 0, 0, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_${dtype}< + decode_policy_head${hdim}, + PipelineStages_Decode, + 0, 0, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_${dtype}< + prefill_policy_head${hdim}, + PipelineStages_Prefill, + 0, 1, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_${dtype}< + decode_paged_policy_head${hdim}, + PipelineStages_Decode, + 0, 1, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); ENDFILE echo " Created flash_fwd_hdim${hdim}_kvcache_paged_${dtype}.cpp" done diff --git a/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim128_kvcache_paged_bf16.cpp b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim128_kvcache_paged_bf16.cpp index a22eabca..d19f10b4 100644 --- a/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim128_kvcache_paged_bf16.cpp +++ b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim128_kvcache_paged_bf16.cpp @@ -18,3 +18,32 @@ template void policy_dispatch_bf16< 0, 1>( sycl::queue& queue, const fmha_fwd_args_t& args); + +// Rotary kvcache variants keep rotary code out of non-rotary kernels. +template void policy_dispatch_bf16< + prefill_policy_head128, + PipelineStages_Prefill, + 0, 0, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_bf16< + decode_policy_head128, + PipelineStages_Decode, + 0, 0, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_bf16< + prefill_policy_head128, + PipelineStages_Prefill, + 0, 1, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_bf16< + decode_paged_policy_head128, + PipelineStages_Decode, + 0, 1, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); diff --git a/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim128_kvcache_paged_fp16.cpp b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim128_kvcache_paged_fp16.cpp index 54050a99..3676037a 100644 --- a/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim128_kvcache_paged_fp16.cpp +++ b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim128_kvcache_paged_fp16.cpp @@ -18,3 +18,32 @@ template void policy_dispatch_fp16< 0, 1>( sycl::queue& queue, const fmha_fwd_args_t& args); + +// Rotary kvcache variants keep rotary code out of non-rotary kernels. +template void policy_dispatch_fp16< + prefill_policy_head128, + PipelineStages_Prefill, + 0, 0, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_fp16< + decode_policy_head128, + PipelineStages_Decode, + 0, 0, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_fp16< + prefill_policy_head128, + PipelineStages_Prefill, + 0, 1, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_fp16< + decode_paged_policy_head128, + PipelineStages_Decode, + 0, 1, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); diff --git a/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim160_kvcache_paged_bf16.cpp b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim160_kvcache_paged_bf16.cpp index d9b4b30f..63fd6b19 100644 --- a/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim160_kvcache_paged_bf16.cpp +++ b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim160_kvcache_paged_bf16.cpp @@ -18,3 +18,32 @@ template void policy_dispatch_bf16< 0, 1>( sycl::queue& queue, const fmha_fwd_args_t& args); + +// Rotary kvcache variants keep rotary code out of non-rotary kernels. +template void policy_dispatch_bf16< + prefill_policy_head160, + PipelineStages_Prefill, + 0, 0, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_bf16< + decode_policy_head160, + PipelineStages_Decode, + 0, 0, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_bf16< + prefill_policy_head160, + PipelineStages_Prefill, + 0, 1, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_bf16< + decode_paged_policy_head160, + PipelineStages_Decode, + 0, 1, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); diff --git a/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim160_kvcache_paged_fp16.cpp b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim160_kvcache_paged_fp16.cpp index d68d6f95..cb8cda7d 100644 --- a/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim160_kvcache_paged_fp16.cpp +++ b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim160_kvcache_paged_fp16.cpp @@ -18,3 +18,32 @@ template void policy_dispatch_fp16< 0, 1>( sycl::queue& queue, const fmha_fwd_args_t& args); + +// Rotary kvcache variants keep rotary code out of non-rotary kernels. +template void policy_dispatch_fp16< + prefill_policy_head160, + PipelineStages_Prefill, + 0, 0, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_fp16< + decode_policy_head160, + PipelineStages_Decode, + 0, 0, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_fp16< + prefill_policy_head160, + PipelineStages_Prefill, + 0, 1, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_fp16< + decode_paged_policy_head160, + PipelineStages_Decode, + 0, 1, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); diff --git a/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim192_kvcache_paged_bf16.cpp b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim192_kvcache_paged_bf16.cpp index 17a0cf3b..14346625 100644 --- a/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim192_kvcache_paged_bf16.cpp +++ b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim192_kvcache_paged_bf16.cpp @@ -18,3 +18,32 @@ template void policy_dispatch_bf16< 0, 1>( sycl::queue& queue, const fmha_fwd_args_t& args); + +// Rotary kvcache variants keep rotary code out of non-rotary kernels. +template void policy_dispatch_bf16< + prefill_policy_head192, + PipelineStages_Prefill, + 0, 0, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_bf16< + decode_policy_head192, + PipelineStages_Decode, + 0, 0, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_bf16< + prefill_policy_head192, + PipelineStages_Prefill, + 0, 1, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_bf16< + decode_paged_policy_head192, + PipelineStages_Decode, + 0, 1, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); diff --git a/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim192_kvcache_paged_fp16.cpp b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim192_kvcache_paged_fp16.cpp index 9c48d19c..a34c90cd 100644 --- a/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim192_kvcache_paged_fp16.cpp +++ b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim192_kvcache_paged_fp16.cpp @@ -18,3 +18,32 @@ template void policy_dispatch_fp16< 0, 1>( sycl::queue& queue, const fmha_fwd_args_t& args); + +// Rotary kvcache variants keep rotary code out of non-rotary kernels. +template void policy_dispatch_fp16< + prefill_policy_head192, + PipelineStages_Prefill, + 0, 0, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_fp16< + decode_policy_head192, + PipelineStages_Decode, + 0, 0, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_fp16< + prefill_policy_head192, + PipelineStages_Prefill, + 0, 1, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_fp16< + decode_paged_policy_head192, + PipelineStages_Decode, + 0, 1, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); diff --git a/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim256_kvcache_paged_bf16.cpp b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim256_kvcache_paged_bf16.cpp index fb753042..4d9f0dde 100644 --- a/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim256_kvcache_paged_bf16.cpp +++ b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim256_kvcache_paged_bf16.cpp @@ -18,3 +18,32 @@ template void policy_dispatch_bf16< 0, 1>( sycl::queue& queue, const fmha_fwd_args_t& args); + +// Rotary kvcache variants keep rotary code out of non-rotary kernels. +template void policy_dispatch_bf16< + prefill_policy_head256, + PipelineStages_Prefill, + 0, 0, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_bf16< + decode_policy_head256, + PipelineStages_Decode, + 0, 0, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_bf16< + prefill_policy_head256, + PipelineStages_Prefill, + 0, 1, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_bf16< + decode_paged_policy_head256, + PipelineStages_Decode, + 0, 1, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); diff --git a/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim256_kvcache_paged_fp16.cpp b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim256_kvcache_paged_fp16.cpp index 435c1c62..3e3b8303 100644 --- a/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim256_kvcache_paged_fp16.cpp +++ b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim256_kvcache_paged_fp16.cpp @@ -18,3 +18,32 @@ template void policy_dispatch_fp16< 0, 1>( sycl::queue& queue, const fmha_fwd_args_t& args); + +// Rotary kvcache variants keep rotary code out of non-rotary kernels. +template void policy_dispatch_fp16< + prefill_policy_head256, + PipelineStages_Prefill, + 0, 0, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_fp16< + decode_policy_head256, + PipelineStages_Decode, + 0, 0, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_fp16< + prefill_policy_head256, + PipelineStages_Prefill, + 0, 1, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_fp16< + decode_paged_policy_head256, + PipelineStages_Decode, + 0, 1, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); diff --git a/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim32_kvcache_paged_bf16.cpp b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim32_kvcache_paged_bf16.cpp index ef035b4b..f5d45308 100644 --- a/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim32_kvcache_paged_bf16.cpp +++ b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim32_kvcache_paged_bf16.cpp @@ -18,3 +18,32 @@ template void policy_dispatch_bf16< 0, 1>( sycl::queue& queue, const fmha_fwd_args_t& args); + +// Rotary kvcache variants keep rotary code out of non-rotary kernels. +template void policy_dispatch_bf16< + prefill_policy_head32, + PipelineStages_Prefill, + 0, 0, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_bf16< + decode_policy_head32, + PipelineStages_Decode, + 0, 0, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_bf16< + prefill_policy_head32, + PipelineStages_Prefill, + 0, 1, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_bf16< + decode_paged_policy_head32, + PipelineStages_Decode, + 0, 1, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); diff --git a/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim32_kvcache_paged_fp16.cpp b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim32_kvcache_paged_fp16.cpp index 2053fc80..3a09b49f 100644 --- a/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim32_kvcache_paged_fp16.cpp +++ b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim32_kvcache_paged_fp16.cpp @@ -18,3 +18,32 @@ template void policy_dispatch_fp16< 0, 1>( sycl::queue& queue, const fmha_fwd_args_t& args); + +// Rotary kvcache variants keep rotary code out of non-rotary kernels. +template void policy_dispatch_fp16< + prefill_policy_head32, + PipelineStages_Prefill, + 0, 0, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_fp16< + decode_policy_head32, + PipelineStages_Decode, + 0, 0, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_fp16< + prefill_policy_head32, + PipelineStages_Prefill, + 0, 1, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_fp16< + decode_paged_policy_head32, + PipelineStages_Decode, + 0, 1, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); diff --git a/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim512_kvcache_paged_bf16.cpp b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim512_kvcache_paged_bf16.cpp index e8515025..2a25b6e9 100644 --- a/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim512_kvcache_paged_bf16.cpp +++ b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim512_kvcache_paged_bf16.cpp @@ -18,3 +18,32 @@ template void policy_dispatch_bf16< 0, 1>( sycl::queue& queue, const fmha_fwd_args_t& args); + +// Rotary kvcache variants keep rotary code out of non-rotary kernels. +template void policy_dispatch_bf16< + prefill_policy_head512, + PipelineStages_Prefill, + 0, 0, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_bf16< + decode_policy_head512, + PipelineStages_Decode, + 0, 0, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_bf16< + prefill_policy_head512, + PipelineStages_Prefill, + 0, 1, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_bf16< + decode_paged_policy_head512, + PipelineStages_Decode, + 0, 1, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); diff --git a/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim512_kvcache_paged_fp16.cpp b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim512_kvcache_paged_fp16.cpp index 95c3cbbf..31b7d308 100644 --- a/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim512_kvcache_paged_fp16.cpp +++ b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim512_kvcache_paged_fp16.cpp @@ -18,3 +18,32 @@ template void policy_dispatch_fp16< 0, 1>( sycl::queue& queue, const fmha_fwd_args_t& args); + +// Rotary kvcache variants keep rotary code out of non-rotary kernels. +template void policy_dispatch_fp16< + prefill_policy_head512, + PipelineStages_Prefill, + 0, 0, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_fp16< + decode_policy_head512, + PipelineStages_Decode, + 0, 0, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_fp16< + prefill_policy_head512, + PipelineStages_Prefill, + 0, 1, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_fp16< + decode_paged_policy_head512, + PipelineStages_Decode, + 0, 1, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); diff --git a/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim64_kvcache_paged_bf16.cpp b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim64_kvcache_paged_bf16.cpp index c793d538..a55d84e9 100644 --- a/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim64_kvcache_paged_bf16.cpp +++ b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim64_kvcache_paged_bf16.cpp @@ -18,3 +18,32 @@ template void policy_dispatch_bf16< 0, 1>( sycl::queue& queue, const fmha_fwd_args_t& args); + +// Rotary kvcache variants keep rotary code out of non-rotary kernels. +template void policy_dispatch_bf16< + prefill_policy_head64, + PipelineStages_Prefill, + 0, 0, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_bf16< + decode_policy_head64, + PipelineStages_Decode, + 0, 0, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_bf16< + prefill_policy_head64, + PipelineStages_Prefill, + 0, 1, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_bf16< + decode_paged_policy_head64, + PipelineStages_Decode, + 0, 1, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); diff --git a/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim64_kvcache_paged_fp16.cpp b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim64_kvcache_paged_fp16.cpp index e6924f25..dd221976 100644 --- a/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim64_kvcache_paged_fp16.cpp +++ b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim64_kvcache_paged_fp16.cpp @@ -18,3 +18,32 @@ template void policy_dispatch_fp16< 0, 1>( sycl::queue& queue, const fmha_fwd_args_t& args); + +// Rotary kvcache variants keep rotary code out of non-rotary kernels. +template void policy_dispatch_fp16< + prefill_policy_head64, + PipelineStages_Prefill, + 0, 0, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_fp16< + decode_policy_head64, + PipelineStages_Decode, + 0, 0, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_fp16< + prefill_policy_head64, + PipelineStages_Prefill, + 0, 1, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_fp16< + decode_paged_policy_head64, + PipelineStages_Decode, + 0, 1, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); diff --git a/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim96_kvcache_paged_bf16.cpp b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim96_kvcache_paged_bf16.cpp index aab66a39..0e270ada 100644 --- a/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim96_kvcache_paged_bf16.cpp +++ b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim96_kvcache_paged_bf16.cpp @@ -18,3 +18,32 @@ template void policy_dispatch_bf16< 0, 1>( sycl::queue& queue, const fmha_fwd_args_t& args); + +// Rotary kvcache variants keep rotary code out of non-rotary kernels. +template void policy_dispatch_bf16< + prefill_policy_head96, + PipelineStages_Prefill, + 0, 0, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_bf16< + decode_policy_head96, + PipelineStages_Decode, + 0, 0, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_bf16< + prefill_policy_head96, + PipelineStages_Prefill, + 0, 1, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_bf16< + decode_paged_policy_head96, + PipelineStages_Decode, + 0, 1, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); diff --git a/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim96_kvcache_paged_fp16.cpp b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim96_kvcache_paged_fp16.cpp index 049ac07c..5cd8032e 100644 --- a/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim96_kvcache_paged_fp16.cpp +++ b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim96_kvcache_paged_fp16.cpp @@ -18,3 +18,32 @@ template void policy_dispatch_fp16< 0, 1>( sycl::queue& queue, const fmha_fwd_args_t& args); + +// Rotary kvcache variants keep rotary code out of non-rotary kernels. +template void policy_dispatch_fp16< + prefill_policy_head96, + PipelineStages_Prefill, + 0, 0, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_fp16< + decode_policy_head96, + PipelineStages_Decode, + 0, 0, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_fp16< + prefill_policy_head96, + PipelineStages_Prefill, + 0, 1, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_fp16< + decode_paged_policy_head96, + PipelineStages_Decode, + 0, 1, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); diff --git a/flash-attn2/flash_attn_xpu/src/fmha_fwd.cpp b/flash-attn2/flash_attn_xpu/src/fmha_fwd.cpp index bb9e4600..27d81d0e 100644 --- a/flash-attn2/flash_attn_xpu/src/fmha_fwd.cpp +++ b/flash-attn2/flash_attn_xpu/src/fmha_fwd.cpp @@ -55,6 +55,19 @@ void dispatch_varlen_decode_paged(sycl::queue& queue, CutlassType cuType, } } +template +void dispatch_kvcache_policy(sycl::queue& queue, CutlassType cuType, + const fmha_fwd_args_t& args, + bool has_rotary) { + if (has_rotary) { + policy_dispatch(queue, cuType, args); + } else { + policy_dispatch(queue, cuType, args); + } +} + /// Dispatch forward kernel by head_size for the varlen prefill path. /// Supported head dimensions are bucketed to the corresponding prefill policies. void dispatch_fwd_varlen_by_head(sycl::queue& queue, CutlassType cuType, @@ -115,15 +128,16 @@ void dispatch_fwd_prefill_by_head(sycl::queue& queue, CutlassType cuType, /// Dispatch forward kernel by head_size for the kvcache prefill paged path /// (non-varlen, seqlen_q > 1, IsPaged=1). void dispatch_fwd_kvcache_prefill_paged_by_head(sycl::queue& queue, CutlassType cuType, - const fmha_fwd_args_t& args, int head_size) { - if (head_size <= 32) policy_dispatch(queue, cuType, args); - else if (head_size <= 64) policy_dispatch(queue, cuType, args); - else if (head_size <= 96) policy_dispatch(queue, cuType, args); - else if (head_size <= 128) policy_dispatch(queue, cuType, args); - else if (head_size <= 160) policy_dispatch(queue, cuType, args); - else if (head_size <= 192) policy_dispatch(queue, cuType, args); - else if (head_size <= 256) policy_dispatch(queue, cuType, args); - else if (head_size == 512) policy_dispatch(queue, cuType, args); + const fmha_fwd_args_t& args, int head_size, + bool has_rotary) { + if (head_size <= 32) dispatch_kvcache_policy(queue, cuType, args, has_rotary); + else if (head_size <= 64) dispatch_kvcache_policy(queue, cuType, args, has_rotary); + else if (head_size <= 96) dispatch_kvcache_policy(queue, cuType, args, has_rotary); + else if (head_size <= 128) dispatch_kvcache_policy(queue, cuType, args, has_rotary); + else if (head_size <= 160) dispatch_kvcache_policy(queue, cuType, args, has_rotary); + else if (head_size <= 192) dispatch_kvcache_policy(queue, cuType, args, has_rotary); + else if (head_size <= 256) dispatch_kvcache_policy(queue, cuType, args, has_rotary); + else if (head_size == 512) dispatch_kvcache_policy(queue, cuType, args, has_rotary); else throw std::runtime_error("Unsupported head_size: " + std::to_string(head_size) + ". Only <= 256 or exactly 512 is supported"); } @@ -131,15 +145,44 @@ void dispatch_fwd_kvcache_prefill_paged_by_head(sycl::queue& queue, CutlassType /// (non-varlen, seqlen_q == 1, IsPaged=1). Uses decode_paged_policy with a /// smaller K-tile so block_size=64 multiples don't overshoot a page. void dispatch_fwd_kvcache_decode_paged_by_head(sycl::queue& queue, CutlassType cuType, - const fmha_fwd_args_t& args, int head_size) { - if (head_size <= 32) policy_dispatch(queue, cuType, args); - else if (head_size <= 64) policy_dispatch(queue, cuType, args); - else if (head_size <= 96) policy_dispatch(queue, cuType, args); - else if (head_size <= 128) policy_dispatch(queue, cuType, args); - else if (head_size <= 160) policy_dispatch(queue, cuType, args); - else if (head_size <= 192) policy_dispatch(queue, cuType, args); - else if (head_size <= 256) policy_dispatch(queue, cuType, args); - else if (head_size == 512) policy_dispatch(queue, cuType, args); + const fmha_fwd_args_t& args, int head_size, + bool has_rotary) { + if (head_size <= 32) dispatch_kvcache_policy(queue, cuType, args, has_rotary); + else if (head_size <= 64) dispatch_kvcache_policy(queue, cuType, args, has_rotary); + else if (head_size <= 96) dispatch_kvcache_policy(queue, cuType, args, has_rotary); + else if (head_size <= 128) dispatch_kvcache_policy(queue, cuType, args, has_rotary); + else if (head_size <= 160) dispatch_kvcache_policy(queue, cuType, args, has_rotary); + else if (head_size <= 192) dispatch_kvcache_policy(queue, cuType, args, has_rotary); + else if (head_size <= 256) dispatch_kvcache_policy(queue, cuType, args, has_rotary); + else if (head_size == 512) dispatch_kvcache_policy(queue, cuType, args, has_rotary); + else throw std::runtime_error("Unsupported head_size: " + std::to_string(head_size) + ". Only <= 256 or exactly 512 is supported"); +} + +void dispatch_fwd_kvcache_decode_by_head(sycl::queue& queue, CutlassType cuType, + const fmha_fwd_args_t& args, int head_size, + bool has_rotary) { + if (head_size <= 32) dispatch_kvcache_policy(queue, cuType, args, has_rotary); + else if (head_size <= 64) dispatch_kvcache_policy(queue, cuType, args, has_rotary); + else if (head_size <= 96) dispatch_kvcache_policy(queue, cuType, args, has_rotary); + else if (head_size <= 128) dispatch_kvcache_policy(queue, cuType, args, has_rotary); + else if (head_size <= 160) dispatch_kvcache_policy(queue, cuType, args, has_rotary); + else if (head_size <= 192) dispatch_kvcache_policy(queue, cuType, args, has_rotary); + else if (head_size <= 256) dispatch_kvcache_policy(queue, cuType, args, has_rotary); + else if (head_size == 512) dispatch_kvcache_policy(queue, cuType, args, has_rotary); + else throw std::runtime_error("Unsupported head_size: " + std::to_string(head_size) + ". Only <= 256 or exactly 512 is supported"); +} + +void dispatch_fwd_kvcache_prefill_by_head(sycl::queue& queue, CutlassType cuType, + const fmha_fwd_args_t& args, int head_size, + bool has_rotary) { + if (head_size <= 32) dispatch_kvcache_policy(queue, cuType, args, has_rotary); + else if (head_size <= 64) dispatch_kvcache_policy(queue, cuType, args, has_rotary); + else if (head_size <= 96) dispatch_kvcache_policy(queue, cuType, args, has_rotary); + else if (head_size <= 128) dispatch_kvcache_policy(queue, cuType, args, has_rotary); + else if (head_size <= 160) dispatch_kvcache_policy(queue, cuType, args, has_rotary); + else if (head_size <= 192) dispatch_kvcache_policy(queue, cuType, args, has_rotary); + else if (head_size <= 256) dispatch_kvcache_policy(queue, cuType, args, has_rotary); + else if (head_size == 512) dispatch_kvcache_policy(queue, cuType, args, has_rotary); else throw std::runtime_error("Unsupported head_size: " + std::to_string(head_size) + ". Only <= 256 or exactly 512 is supported"); } @@ -351,6 +394,10 @@ void cutlass_fmha_fwd_kvcache_impl( const std::optional& knew, const std::optional& vnew, const std::optional& block_table, + const std::optional& rotary_cos, + const std::optional& rotary_sin, + int rotary_dim, + bool is_rotary_interleaved, int max_seqlen_k_paged, float sm_scale, int window_size_left, @@ -436,10 +483,16 @@ void cutlass_fmha_fwd_kvcache_impl( knew.has_value() ? knew->stride(1) : 0, vnew.has_value() ? vnew->stride(0) : 0, vnew.has_value() ? vnew->stride(2) : 0, - vnew.has_value() ? vnew->stride(1) : 0}; + vnew.has_value() ? vnew->stride(1) : 0, + rotary_cos.has_value() ? rotary_cos->data_ptr() : nullptr, + rotary_sin.has_value() ? rotary_sin->data_ptr() : nullptr, + rotary_dim, + is_rotary_interleaved}; const CutlassType cuType = aten_to_Cutlass_dtype(query); const int h = args.head_size; + const bool has_rotary = args.rotary_dim > 0 && args.rotary_cos != nullptr && + args.rotary_sin != nullptr; if (is_paged) { // Paged dispatch requires the K-tile to evenly divide block_size, @@ -454,15 +507,15 @@ void cutlass_fmha_fwd_kvcache_impl( "Paged KV block_size must be a multiple of the kernel K tile. " "Got block_size=", block_size, ", K tile=", k_tile_n); if (max_seqlen_q == 1) { - dispatch_fwd_kvcache_decode_paged_by_head(queue, cuType, args, h); + dispatch_fwd_kvcache_decode_paged_by_head(queue, cuType, args, h, has_rotary); } else { - dispatch_fwd_kvcache_prefill_paged_by_head(queue, cuType, args, h); + dispatch_fwd_kvcache_prefill_paged_by_head(queue, cuType, args, h, has_rotary); } } else { if (max_seqlen_q == 1) { - dispatch_fwd_decode_by_head(queue, cuType, args, h); + dispatch_fwd_kvcache_decode_by_head(queue, cuType, args, h, has_rotary); } else { - dispatch_fwd_prefill_by_head(queue, cuType, args, h); + dispatch_fwd_kvcache_prefill_by_head(queue, cuType, args, h, has_rotary); } } } diff --git a/flash-attn2/flash_attn_xpu/src/fmha_fwd.hpp b/flash-attn2/flash_attn_xpu/src/fmha_fwd.hpp index 2179e7a5..ed1576f1 100644 --- a/flash-attn2/flash_attn_xpu/src/fmha_fwd.hpp +++ b/flash-attn2/flash_attn_xpu/src/fmha_fwd.hpp @@ -44,27 +44,30 @@ struct decode_paged_policy_head256; struct decode_paged_policy_head512; // Dtype-specific dispatch functions (instantiated in per-head TUs) -template +template void policy_dispatch_fp16( sycl::queue& queue, const fmha_fwd_args_t& args); -template +template void policy_dispatch_bf16( sycl::queue& queue, const fmha_fwd_args_t& args); // Combined dispatch (delegates to fp16/bf16 based on cuType) // Defined inline in header so callers (fmha_fwd.cpp) can see the template body. -template +template inline void policy_dispatch( sycl::queue& queue, CutlassType cuType, const fmha_fwd_args_t& args) { if (cuType == CutlassType::half) { - policy_dispatch_fp16(queue, args); + policy_dispatch_fp16(queue, args); } else { - policy_dispatch_bf16(queue, args); + policy_dispatch_bf16(queue, args); } } @@ -189,6 +192,10 @@ void cutlass_fmha_fwd_kvcache_impl( const std::optional& knew, const std::optional& vnew, const std::optional& block_table, + const std::optional& rotary_cos, + const std::optional& rotary_sin, + int rotary_dim, + bool is_rotary_interleaved, int max_seqlen_k_paged, float sm_scale, int window_size_left, diff --git a/flash-attn2/flash_attn_xpu/src/fmha_fwd_impl.hpp b/flash-attn2/flash_attn_xpu/src/fmha_fwd_impl.hpp index 7f12a1e7..282f0bb7 100644 --- a/flash-attn2/flash_attn_xpu/src/fmha_fwd_impl.hpp +++ b/flash-attn2/flash_attn_xpu/src/fmha_fwd_impl.hpp @@ -72,7 +72,8 @@ struct FMHAConfig { bool Local, bool Dropout, bool Paged, - bool VarLen> + bool VarLen, + bool HasRotary> static void run_xe2(sycl::queue& queue, const fmha_fwd_args_t& args) { cutlass::KernelHardwareInfo hw_info; @@ -118,7 +119,8 @@ struct FMHAConfig { TensorV, GmemTiledCopyQ, GmemTiledCopyK, - GmemTiledCopyV>; + GmemTiledCopyV, + HasRotary>; using CollectiveEpilogue = cutlass::fmha::collective::FMHAFwdEpilogueXe2< CollectiveMainloop, @@ -213,7 +215,11 @@ struct FMHAConfig { static_cast(args.block_table), args.block_size, args.max_blocks_per_seq, - args.total_seqlen_k}, + args.total_seqlen_k, + reinterpret_cast(args.rotary_cos), + reinterpret_cast(args.rotary_sin), + args.rotary_dim, + args.is_rotary_interleaved}, {}, hw_info}; @@ -249,6 +255,7 @@ struct FMHAConfig { template < int IsVarLen, int IsPaged, + bool HasRotary = false, class Scheduler = cutlass::fmha::kernel::XeFHMAIndividualTileScheduler> static void xe2_dispatch(sycl::queue& queue, const fmha_fwd_args_t& args) { @@ -262,7 +269,7 @@ struct FMHAConfig { bool has_dropout = args.p_dropout > 0.0f; #define XE2_CASE(C, L, D) \ if (args.is_causal == C && args.is_local == L && has_dropout == D) { \ - run_xe2(queue, args); \ + run_xe2(queue, args); \ return; \ } XE2_CASE(false, false, false) @@ -279,7 +286,8 @@ struct FMHAConfig { // Single-dtype dispatch: only instantiates one dtype path per TU to reduce // per-file IGC memory usage from ~40 GB to ~20 GB. -template +template void policy_dispatch_fp16(sycl::queue& queue, const fmha_fwd_args_t& args) { using Config = FMHAConfig< typename chunk_policy::ShapeQK, @@ -289,10 +297,11 @@ void policy_dispatch_fp16(sycl::queue& queue, const fmha_fwd_args_t& args) { void, PipelineStages, half_t, half_t, half_t, half_t>; - Config::template xe2_dispatch(queue, args); + Config::template xe2_dispatch(queue, args); } -template +template void policy_dispatch_bf16(sycl::queue& queue, const fmha_fwd_args_t& args) { using Config = FMHAConfig< typename chunk_policy::ShapeQK, @@ -302,7 +311,7 @@ void policy_dispatch_bf16(sycl::queue& queue, const fmha_fwd_args_t& args) { void, PipelineStages, bfloat16_t, bfloat16_t, bfloat16_t, bfloat16_t>; - Config::template xe2_dispatch(queue, args); + Config::template xe2_dispatch(queue, args); } // Combined policy_dispatch is now defined inline in fmha_fwd.hpp diff --git a/flash-attn2/flash_attn_xpu/src/fmha_fwd_types.hpp b/flash-attn2/flash_attn_xpu/src/fmha_fwd_types.hpp index ce3f857e..7730faca 100644 --- a/flash-attn2/flash_attn_xpu/src/fmha_fwd_types.hpp +++ b/flash-attn2/flash_attn_xpu/src/fmha_fwd_types.hpp @@ -55,6 +55,12 @@ struct fmha_fwd_args_t { int64_t vnew_batch_stride = 0; int64_t vnew_head_stride = 0; int64_t vnew_row_stride = 0; + + // Fused rotary embedding for kvcache append and Q load + void* rotary_cos = nullptr; + void* rotary_sin = nullptr; + int rotary_dim = 0; + bool is_rotary_interleaved = true; }; enum class CutlassType { diff --git a/flash-attn2/flash_attn_xpu/src/kernel/fmha_fwd_kernel_xe2.hpp b/flash-attn2/flash_attn_xpu/src/kernel/fmha_fwd_kernel_xe2.hpp index 5928a4a1..ceb9209a 100644 --- a/flash-attn2/flash_attn_xpu/src/kernel/fmha_fwd_kernel_xe2.hpp +++ b/flash-attn2/flash_attn_xpu/src/kernel/fmha_fwd_kernel_xe2.hpp @@ -233,9 +233,9 @@ class XeFMHAFwdKernelXe2 { // - non-paged path: bidx selects the per-batch KV slice int bidx = (!PagedKV && p.cache_batch_idx) ? p.cache_batch_idx[idx_b] : idx_b; if (p.cache_seqlens) { - int orig_cache_seqlens = p.cache_seqlens[bidx]; + int orig_cache_seqlens = p.cache_seqlens[idx_b]; if (p.cache_leftpad) { - leftpad_k = p.cache_leftpad[bidx]; + leftpad_k = p.cache_leftpad[idx_b]; } // Fused cache update: copy knew/vnew into kcache/vcache @@ -275,7 +275,24 @@ class XeFMHAFwdKernelXe2 { + head * p.v_head_stride + static_cast(page_off) * p.v_row_stride; for (int d = thr_id; d < s.head_size_qk; d += num_threads) { - k_dst[d] = k_src[si * p.knew_row_stride + d]; + auto k_value = k_src[si * p.knew_row_stride + d]; + if constexpr (CollectiveMainloop::HasRotary) { + if (params.mainloop.rotary.rotary_dim > 0 && + params.mainloop.rotary.rotary_cos != nullptr && + params.mainloop.rotary.rotary_sin != nullptr && + d < params.mainloop.rotary.rotary_dim) { + int pair_dim = cutlass::fmha::collective::rotary_pair_dim( + d, params.mainloop.rotary.rotary_dim, + params.mainloop.rotary.is_rotary_interleaved); + k_value = cutlass::fmha::collective::apply_rotary_scalar( + k_value, k_src[si * p.knew_row_stride + pair_dim], + params.mainloop.rotary.rotary_cos, + params.mainloop.rotary.rotary_sin, + global_pos, d, params.mainloop.rotary.rotary_dim, + params.mainloop.rotary.is_rotary_interleaved); + } + } + k_dst[d] = k_value; } for (int d = thr_id; d < s.head_size_vo; d += num_threads) { v_dst[d] = v_src[si * p.vnew_row_stride + d]; @@ -290,7 +307,25 @@ class XeFMHAFwdKernelXe2 { + static_cast(orig_cache_seqlens) * p.v_row_stride; for (int si = 0; si < p.seqlen_knew; si++) { for (int d = thr_id; d < s.head_size_qk; d += num_threads) { - k_dst[si * p.k_row_stride + d] = k_src[si * p.knew_row_stride + d]; + auto k_value = k_src[si * p.knew_row_stride + d]; + if constexpr (CollectiveMainloop::HasRotary) { + if (params.mainloop.rotary.rotary_dim > 0 && + params.mainloop.rotary.rotary_cos != nullptr && + params.mainloop.rotary.rotary_sin != nullptr && + d < params.mainloop.rotary.rotary_dim) { + int pair_dim = cutlass::fmha::collective::rotary_pair_dim( + d, params.mainloop.rotary.rotary_dim, + params.mainloop.rotary.is_rotary_interleaved); + k_value = cutlass::fmha::collective::apply_rotary_scalar( + k_value, k_src[si * p.knew_row_stride + pair_dim], + params.mainloop.rotary.rotary_cos, + params.mainloop.rotary.rotary_sin, + orig_cache_seqlens + si, d, + params.mainloop.rotary.rotary_dim, + params.mainloop.rotary.is_rotary_interleaved); + } + } + k_dst[si * p.k_row_stride + d] = k_value; } } for (int si = 0; si < p.seqlen_knew; si++) { @@ -462,7 +497,9 @@ class XeFMHAFwdKernelXe2 { tile_row_idx, rows_of_maxima, head_q, - s.num_heads_q); + s.num_heads_q, + q_offset_sg, + p.cache_seqlens ? p.cache_seqlens[idx_b] : 0); if constexpr ( !is_empty_v && From b50ffa5ec5b3d6c3398df9af30b588aa3f150d70 Mon Sep 17 00:00:00 2001 From: YangKai0616 Date: Sat, 9 May 2026 05:33:45 +0000 Subject: [PATCH 6/8] Remove fallback --- flash-attn2/build.toml | 1 - flash-attn2/flash_attn_xpu/flash_api.cpp | 81 ++---------- .../src/kernel/fmha_fwd_kernel_xe2.hpp | 30 +++-- flash-attn2/flash_attn_xpu/src/rotary.hpp | 118 ------------------ 4 files changed, 25 insertions(+), 205 deletions(-) delete mode 100644 flash-attn2/flash_attn_xpu/src/rotary.hpp diff --git a/flash-attn2/build.toml b/flash-attn2/build.toml index 3a050c2d..c9a63c4f 100644 --- a/flash-attn2/build.toml +++ b/flash-attn2/build.toml @@ -174,7 +174,6 @@ depends = [ src = [ "flash_attn_xpu/flash_api.cpp", "flash_attn_xpu/src/philox.hpp", - "flash_attn_xpu/src/rotary.hpp", "flash_attn_xpu/src/fmha_fwd_types.hpp", "flash_attn_xpu/src/fmha_fwd.hpp", "flash_attn_xpu/src/fmha_fwd_impl.hpp", diff --git a/flash-attn2/flash_attn_xpu/flash_api.cpp b/flash-attn2/flash_attn_xpu/flash_api.cpp index 1ef3cfd8..5923a7dc 100644 --- a/flash-attn2/flash_attn_xpu/flash_api.cpp +++ b/flash-attn2/flash_attn_xpu/flash_api.cpp @@ -469,8 +469,6 @@ mha_varlen_fwd( return {out, softmax_lse, S_dmask, rng_state}; } -#include "src/rotary.hpp" - std::vector mha_fwd_kvcache( at::Tensor &q, @@ -622,77 +620,7 @@ mha_fwd_kvcache( TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32"); } - // Write new K/V to cache. - // - // Strategy: - // - Always prefer kernel-fused scatter (passes knew/vnew to the kernel, - // which writes them in-place during the prologue). This avoids any - // host sync and works for both contiguous and paged caches. - // - Fall back to API-layer scatter only when padding is needed: the - // padded cache is a separate buffer, so the in-kernel writer would - // not update the user tensor. - bool fuse_knew = k_.has_value() && seqlen_new > 0 - && !needs_padding; - if (has_rotary && !fuse_knew) { - std::optional seqlen_offsets_opt; - if (seqlens_k_.has_value()) { seqlen_offsets_opt = seqlens_k; } - - bool is_local = (window_size_left >= 0); - if (is_causal || is_local) { - apply_rotary_emb_inplace(q_padded, rotary_cos, rotary_sin, seqlen_offsets_opt, is_rotary_interleaved); - } else { - auto q_shape = q_padded.sizes(); - auto q_reshaped = q_padded.view({q_shape[0], 1, q_shape[1] * q_shape[2], q_shape[3]}); - apply_rotary_emb_inplace(q_reshaped, rotary_cos, rotary_sin, seqlen_offsets_opt, is_rotary_interleaved); - } - apply_rotary_emb_inplace(k_padded, rotary_cos, rotary_sin, seqlen_offsets_opt, is_rotary_interleaved); - } - if (k_.has_value() && seqlen_new > 0 && !fuse_knew) { - auto seqlens_cpu = seqlens_k.to(torch::kCPU); - auto seqlens_accessor = seqlens_cpu.accessor(); - - at::Tensor k_for_cache = has_rotary - ? k_padded.index({torch::indexing::Slice(), torch::indexing::Slice(), - torch::indexing::Slice(), torch::indexing::Slice(0, head_size_og)}).contiguous() - : ensure_contiguous(k_.value()); - at::Tensor v_for_cache = ensure_contiguous(v_.value()); - - at::Tensor kc = ensure_contiguous(kcache); - at::Tensor vc = ensure_contiguous(vcache); - - if (paged_KV) { - auto bt_cpu = block_table.to(torch::kCPU); - auto bt_acc = bt_cpu.accessor(); - for (int b = 0; b < batch_size; b++) { - int cache_seqlen = seqlens_accessor[b]; - for (int s = 0; s < seqlen_new; s++) { - int global_pos = cache_seqlen + s; - int page_idx = global_pos / page_block_size; - int page_offset = global_pos % page_block_size; - int block_idx = bt_acc[b][page_idx]; - kc.index({block_idx, page_offset}) = k_for_cache.index({b, s}); - vc.index({block_idx, page_offset}) = v_for_cache.index({b, s}); - } - } - } else { - for (int b = 0; b < batch_size; b++) { - int cache_b = cache_batch_idx_.has_value() - ? cache_batch_idx.index({b}).item() : b; - int cache_seqlen = seqlens_accessor[b]; - int write_start = cache_seqlen; - TORCH_CHECK(write_start + seqlen_new <= seqlen_k, - "Cache overflow: cache_seqlen + seqlen_new > cache capacity"); - kc.index({cache_b, torch::indexing::Slice(write_start, write_start + seqlen_new)}) = - k_for_cache.index({b}); - vc.index({cache_b, torch::indexing::Slice(write_start, write_start + seqlen_new)}) = - v_for_cache.index({b}); - } - } - - kcache_padded = maybe_pad(kc); - vcache_padded = maybe_pad(vc); - seqlens_k = seqlens_k + seqlen_new; - } + bool fuse_knew = k_.has_value() && seqlen_new > 0; // Dispatch to kernel. Paged caches are now passed natively (block_table // routed straight through to the kernel, no host gather). @@ -748,6 +676,13 @@ mha_fwd_kvcache( torch::indexing::Slice(), torch::indexing::Slice(0, head_size_og)}) .contiguous(); if (out_.has_value()) { out_.value().copy_(out); } + if (fuse_knew) { + // The fused kernel updates the padded cache buffer; publish valid dims back to the user cache. + kcache.copy_(kcache_padded.index({torch::indexing::Slice(), torch::indexing::Slice(), + torch::indexing::Slice(), torch::indexing::Slice(0, head_size_og)})); + vcache.copy_(vcache_padded.index({torch::indexing::Slice(), torch::indexing::Slice(), + torch::indexing::Slice(), torch::indexing::Slice(0, head_size_og)})); + } } return {out, softmax_lse}; diff --git a/flash-attn2/flash_attn_xpu/src/kernel/fmha_fwd_kernel_xe2.hpp b/flash-attn2/flash_attn_xpu/src/kernel/fmha_fwd_kernel_xe2.hpp index ceb9209a..5f78f6b9 100644 --- a/flash-attn2/flash_attn_xpu/src/kernel/fmha_fwd_kernel_xe2.hpp +++ b/flash-attn2/flash_attn_xpu/src/kernel/fmha_fwd_kernel_xe2.hpp @@ -284,12 +284,14 @@ class XeFMHAFwdKernelXe2 { int pair_dim = cutlass::fmha::collective::rotary_pair_dim( d, params.mainloop.rotary.rotary_dim, params.mainloop.rotary.is_rotary_interleaved); - k_value = cutlass::fmha::collective::apply_rotary_scalar( - k_value, k_src[si * p.knew_row_stride + pair_dim], - params.mainloop.rotary.rotary_cos, - params.mainloop.rotary.rotary_sin, - global_pos, d, params.mainloop.rotary.rotary_dim, - params.mainloop.rotary.is_rotary_interleaved); + if (pair_dim < s.head_size_qk) { + k_value = cutlass::fmha::collective::apply_rotary_scalar( + k_value, k_src[si * p.knew_row_stride + pair_dim], + params.mainloop.rotary.rotary_cos, + params.mainloop.rotary.rotary_sin, + global_pos, d, params.mainloop.rotary.rotary_dim, + params.mainloop.rotary.is_rotary_interleaved); + } } } k_dst[d] = k_value; @@ -316,13 +318,15 @@ class XeFMHAFwdKernelXe2 { int pair_dim = cutlass::fmha::collective::rotary_pair_dim( d, params.mainloop.rotary.rotary_dim, params.mainloop.rotary.is_rotary_interleaved); - k_value = cutlass::fmha::collective::apply_rotary_scalar( - k_value, k_src[si * p.knew_row_stride + pair_dim], - params.mainloop.rotary.rotary_cos, - params.mainloop.rotary.rotary_sin, - orig_cache_seqlens + si, d, - params.mainloop.rotary.rotary_dim, - params.mainloop.rotary.is_rotary_interleaved); + if (pair_dim < s.head_size_qk) { + k_value = cutlass::fmha::collective::apply_rotary_scalar( + k_value, k_src[si * p.knew_row_stride + pair_dim], + params.mainloop.rotary.rotary_cos, + params.mainloop.rotary.rotary_sin, + orig_cache_seqlens + si, d, + params.mainloop.rotary.rotary_dim, + params.mainloop.rotary.is_rotary_interleaved); + } } } k_dst[si * p.k_row_stride + d] = k_value; diff --git a/flash-attn2/flash_attn_xpu/src/rotary.hpp b/flash-attn2/flash_attn_xpu/src/rotary.hpp deleted file mode 100644 index bff33875..00000000 --- a/flash-attn2/flash_attn_xpu/src/rotary.hpp +++ /dev/null @@ -1,118 +0,0 @@ -#pragma once - -#include -#include -#include - -/// Apply rotary embedding to Q or K tensor (interleaved mode) -template -struct ApplyRotaryInterleavedKernel { - scalar_t* x; - const scalar_t* cos; - const scalar_t* sin; - const int* seqlen_offsets; - int batch_size, seqlen, num_heads, head_dim, rotary_dim, cos_sin_stride; - - void operator()(sycl::nd_item<1> item) const { - int idx = item.get_global_id(0); - int total_elements = batch_size * seqlen * num_heads * (rotary_dim / 2); - if (idx >= total_elements) return; - int half_rotary = rotary_dim / 2; - int pair_idx = idx % half_rotary; - int temp = idx / half_rotary; - int head_idx = temp % num_heads; - temp = temp / num_heads; - int seq_idx = temp % seqlen; - int batch_idx = temp / seqlen; - int pos = (seqlen_offsets != nullptr) ? seqlen_offsets[batch_idx] + seq_idx : seq_idx; - float c = static_cast(cos[pos * cos_sin_stride + pair_idx]); - float s = static_cast(sin[pos * cos_sin_stride + pair_idx]); - int base_offset = ((batch_idx * seqlen + seq_idx) * num_heads + head_idx) * head_dim; - int x0_idx = base_offset + pair_idx * 2; - int x1_idx = base_offset + pair_idx * 2 + 1; - float x0 = static_cast(x[x0_idx]); - float x1 = static_cast(x[x1_idx]); - x[x0_idx] = static_cast(x0 * c - x1 * s); - x[x1_idx] = static_cast(x0 * s + x1 * c); - } -}; - -/// Apply rotary embedding (non-interleaved / GPT-NeoX style) -template -struct ApplyRotaryContiguousKernel { - scalar_t* x; - const scalar_t* cos; - const scalar_t* sin; - const int* seqlen_offsets; - int batch_size, seqlen, num_heads, head_dim, rotary_dim, cos_sin_stride; - - void operator()(sycl::nd_item<1> item) const { - int idx = item.get_global_id(0); - int total_elements = batch_size * seqlen * num_heads * (rotary_dim / 2); - if (idx >= total_elements) return; - int half_rotary = rotary_dim / 2; - int pair_idx = idx % half_rotary; - int temp = idx / half_rotary; - int head_idx = temp % num_heads; - temp = temp / num_heads; - int seq_idx = temp % seqlen; - int batch_idx = temp / seqlen; - int pos = (seqlen_offsets != nullptr) ? seqlen_offsets[batch_idx] + seq_idx : seq_idx; - float c = static_cast(cos[pos * cos_sin_stride + pair_idx]); - float s = static_cast(sin[pos * cos_sin_stride + pair_idx]); - int base_offset = ((batch_idx * seqlen + seq_idx) * num_heads + head_idx) * head_dim; - int x0_idx = base_offset + pair_idx; - int x1_idx = base_offset + pair_idx + half_rotary; - float x0 = static_cast(x[x0_idx]); - float x1 = static_cast(x[x1_idx]); - x[x0_idx] = static_cast(x0 * c - x1 * s); - x[x1_idx] = static_cast(x0 * s + x1 * c); - } -}; - -inline void apply_rotary_emb_inplace( - at::Tensor& x, - const at::Tensor& cos, - const at::Tensor& sin, - const std::optional& seqlen_offsets, - bool interleaved) { - auto batch_size = x.size(0); - auto seqlen = x.size(1); - auto num_heads = x.size(2); - auto head_dim = x.size(3); - auto rotary_dim = cos.size(1) * 2; - TORCH_CHECK(rotary_dim <= head_dim, "rotary_dim must be <= head_dim"); - auto queue = c10::xpu::getCurrentXPUStream().queue(); - int total_pairs = batch_size * seqlen * num_heads * (rotary_dim / 2); - int wg_size = 256; - int num_groups = (total_pairs + wg_size - 1) / wg_size; - if (interleaved) { - AT_DISPATCH_FLOATING_TYPES_AND2( - at::kBFloat16, at::kHalf, x.scalar_type(), "apply_rotary_interleaved", [&] { - const int* offset_ptr = seqlen_offsets.has_value() - ? seqlen_offsets->data_ptr() : nullptr; - ApplyRotaryInterleavedKernel kernel{ - x.data_ptr(), cos.data_ptr(), - sin.data_ptr(), offset_ptr, - (int)batch_size, (int)seqlen, (int)num_heads, - (int)head_dim, (int)rotary_dim, (int)cos.size(1)}; - queue.submit([&](sycl::handler& h) { - h.parallel_for(sycl::nd_range<1>(num_groups * wg_size, wg_size), kernel); - }); - }); - } else { - AT_DISPATCH_FLOATING_TYPES_AND2( - at::kBFloat16, at::kHalf, x.scalar_type(), "apply_rotary_contiguous", [&] { - const int* offset_ptr = seqlen_offsets.has_value() - ? seqlen_offsets->data_ptr() : nullptr; - ApplyRotaryContiguousKernel kernel{ - x.data_ptr(), cos.data_ptr(), - sin.data_ptr(), offset_ptr, - (int)batch_size, (int)seqlen, (int)num_heads, - (int)head_dim, (int)rotary_dim, (int)cos.size(1)}; - queue.submit([&](sycl::handler& h) { - h.parallel_for(sycl::nd_range<1>(num_groups * wg_size, wg_size), kernel); - }); - }); - } -} From 0dcc9c598675980abd786f460c27c2680c67e6bc Mon Sep 17 00:00:00 2001 From: YangKai0616 Date: Sat, 9 May 2026 05:37:57 +0000 Subject: [PATCH 7/8] Remove comments --- flash-attn2/flash_attn_xpu/src/fmha_fwd_impl.hpp | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/flash-attn2/flash_attn_xpu/src/fmha_fwd_impl.hpp b/flash-attn2/flash_attn_xpu/src/fmha_fwd_impl.hpp index 282f0bb7..e686170c 100644 --- a/flash-attn2/flash_attn_xpu/src/fmha_fwd_impl.hpp +++ b/flash-attn2/flash_attn_xpu/src/fmha_fwd_impl.hpp @@ -1,17 +1,3 @@ -/*************************************************************************************************** - * Copyright (C) 2025 Intel Corporation, All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Xe2 (BMG / Arc Pro B60) FMHA forward dispatch. Builds the full - * mainloop+epilogue+kernel for a given tile policy and feature combination, - * and launches it. - * - * The PVC path has been removed; only the Xe2 fork is built. Each per-head - * translation unit instantiates the (Causal x Local x Dropout) cases for a - * single (IsVarLen, IsPaged) combination, controlled by the IsVarLen / IsPaged - * template arguments to policy_dispatch. - **************************************************************************************************/ - #pragma once #include "fmha_fwd_types.hpp" From a7e48e23734ab4ae347aafb4b4dec7c794ced7f9 Mon Sep 17 00:00:00 2001 From: YangKai0616 Date: Fri, 15 May 2026 03:11:33 +0000 Subject: [PATCH 8/8] Change build cores to 4 --- build-concurrency.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/build-concurrency.json b/build-concurrency.json index 81bae71b..abb5747b 100644 --- a/build-concurrency.json +++ b/build-concurrency.json @@ -2,7 +2,7 @@ "flash-attn2": { "xpu": { "max-jobs": 1, - "cores": 8 + "cores": 4 } } }