diff --git a/flash-attn2/build.toml b/flash-attn2/build.toml
index 0a4aff81..c9a63c4f 100644
--- a/flash-attn2/build.toml
+++ b/flash-attn2/build.toml
@@ -210,6 +210,22 @@ src = [
     "flash_attn_xpu/src/flash_fwd_hdim256_fix_bf16.cpp",
     "flash_attn_xpu/src/flash_fwd_hdim512_fix_fp16.cpp",
     "flash_attn_xpu/src/flash_fwd_hdim512_fix_bf16.cpp",
+    "flash_attn_xpu/src/flash_fwd_hdim32_kvcache_paged_fp16.cpp",
+    "flash_attn_xpu/src/flash_fwd_hdim32_kvcache_paged_bf16.cpp",
+    "flash_attn_xpu/src/flash_fwd_hdim64_kvcache_paged_fp16.cpp",
+    "flash_attn_xpu/src/flash_fwd_hdim64_kvcache_paged_bf16.cpp",
+    "flash_attn_xpu/src/flash_fwd_hdim96_kvcache_paged_fp16.cpp",
+    "flash_attn_xpu/src/flash_fwd_hdim96_kvcache_paged_bf16.cpp",
+    "flash_attn_xpu/src/flash_fwd_hdim128_kvcache_paged_fp16.cpp",
+    "flash_attn_xpu/src/flash_fwd_hdim128_kvcache_paged_bf16.cpp",
+    "flash_attn_xpu/src/flash_fwd_hdim160_kvcache_paged_fp16.cpp",
+    "flash_attn_xpu/src/flash_fwd_hdim160_kvcache_paged_bf16.cpp",
+    "flash_attn_xpu/src/flash_fwd_hdim192_kvcache_paged_fp16.cpp",
+    "flash_attn_xpu/src/flash_fwd_hdim192_kvcache_paged_bf16.cpp",
+    "flash_attn_xpu/src/flash_fwd_hdim256_kvcache_paged_fp16.cpp",
+    "flash_attn_xpu/src/flash_fwd_hdim256_kvcache_paged_bf16.cpp",
+    "flash_attn_xpu/src/flash_fwd_hdim512_kvcache_paged_fp16.cpp",
+    "flash_attn_xpu/src/flash_fwd_hdim512_kvcache_paged_bf16.cpp",
     "flash_attn_xpu/src/fmha_bwd_types.hpp",
     "flash_attn_xpu/src/fmha_bwd.hpp",
     "flash_attn_xpu/src/fmha_bwd_impl.hpp",
diff --git a/flash-attn2/flash_attn_xpu/flash_api.cpp b/flash-attn2/flash_attn_xpu/flash_api.cpp
index a4bb4c5c..5923a7dc 100644
--- a/flash-attn2/flash_attn_xpu/flash_api.cpp
+++ b/flash-attn2/flash_attn_xpu/flash_api.cpp
@@ -2,6 +2,8 @@
 #include
 #include
 #include
+#include
+#include
 #include "src/fmha_fwd.hpp"
 #include "src/fmha_bwd.hpp"
@@ -466,6 +468,226 @@ mha_varlen_fwd(
     at::Tensor rng_state;
     return {out, softmax_lse, S_dmask, rng_state};
 }
+
+std::vector<at::Tensor>
+mha_fwd_kvcache(
+    at::Tensor &q,
+    const at::Tensor &kcache,
+    const at::Tensor &vcache,
+    std::optional<const at::Tensor> &k_,
+    std::optional<const at::Tensor> &v_,
+    std::optional<const at::Tensor> &seqlens_k_,
+    std::optional<const at::Tensor> &rotary_cos_,
+    std::optional<const at::Tensor> &rotary_sin_,
+    std::optional<const at::Tensor> &cache_batch_idx_,
+    std::optional<const at::Tensor> &leftpad_k_,
+    std::optional<at::Tensor> &block_table_,
+    std::optional<at::Tensor> &alibi_slopes_,
+    std::optional<at::Tensor> &out_,
+    const float softmax_scale,
+    bool is_causal,
+    int window_size_left,
+    int window_size_right,
+    const float softcap,
+    bool is_rotary_interleaved,
+    int num_splits) {
+    auto device_idx = q.device().index();
+    compat::select_device(device_idx);
+
+    TORCH_CHECK(!alibi_slopes_.has_value(),
+                "FlashAttention KVCache on XPU does not support alibi_slopes.");
+
+    auto q_dtype = q.dtype();
+    TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
+                "FlashAttention KVCache only supports fp16 and bf16 data type");
+    TORCH_CHECK(kcache.dtype() == q_dtype, "query and key cache must have the same dtype");
+    TORCH_CHECK(vcache.dtype() == q_dtype, "query and value cache must have the same dtype");
+
+    CHECK_DEVICE(q); CHECK_DEVICE(kcache); CHECK_DEVICE(vcache);
+    TORCH_CHECK(q.stride(-1) == 1, "Query must have contiguous last dimension");
+    TORCH_CHECK(kcache.stride(-1) == 1, "Key cache must have contiguous last dimension");
+    TORCH_CHECK(vcache.stride(-1) == 1, "Value cache must have contiguous last dimension");
+
+    const bool paged_KV = block_table_.has_value();
+    at::Tensor block_table;
+    if (paged_KV) {
+        TORCH_CHECK(!cache_batch_idx_.has_value(), "Paged KVcache does not support cache_batch_idx");
+        block_table = block_table_.value();
+        CHECK_DEVICE(block_table);
+        TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
+    }
+
+    const auto sizes = q.sizes();
+    const int batch_size = sizes[0];
+    int seqlen_q = sizes[1];
+    int num_heads = sizes[2];
+    const int head_size_og = sizes[3];
+
+    const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);
+    const int page_block_size = !paged_KV ? 1 : kcache.size(1);
+    const int seqlen_k = !paged_KV ? kcache.size(1) : max_num_blocks_per_seq * page_block_size;
+    const int num_heads_k = kcache.size(2);
+
+    TORCH_CHECK(batch_size > 0, "batch size must be positive");
+    TORCH_CHECK(head_size_og <= 256 || head_size_og == 512,
+                "FlashAttention KVCache only supports head dimension up to 256 or exactly 512");
+    TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
+
+    if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; }
+    if (is_causal) { window_size_right = 0; }
+    if (window_size_left >= seqlen_k) { window_size_left = -1; }
+    if (window_size_right >= seqlen_k) { window_size_right = -1; }
+
+    const int head_size_padded = round_multiple(head_size_og, 32);
+    const bool needs_padding = (head_size_og != head_size_padded);
+    const int pad_size = head_size_padded - head_size_og;
+
+    auto maybe_pad = [&](const at::Tensor& t) -> at::Tensor {
+        return needs_padding
+            ? torch::nn::functional::pad(t, torch::nn::functional::PadFuncOptions({0, pad_size}))
+            : t;
+    };
+
+    at::Tensor q_padded = maybe_pad(ensure_contiguous(q));
+    at::Tensor kcache_padded = maybe_pad(ensure_contiguous(kcache));
+    at::Tensor vcache_padded = maybe_pad(ensure_contiguous(vcache));
+
+    at::Tensor out;
+    if (out_.has_value()) {
+        out = out_.value();
+        TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
+        CHECK_DEVICE(out);
+        if (needs_padding) { out = maybe_pad(out); }
+    } else {
+        out = torch::zeros_like(q_padded);
+    }
+
+    auto opts = q.options();
+    auto softmax_lse = torch::full({batch_size, num_heads, seqlen_q},
+                                   -std::numeric_limits<float>::infinity(),
+                                   opts.dtype(at::kFloat));
+
+    // Handle new K/V
+    at::Tensor k_padded, v_padded;
+    int seqlen_new = 0;
+    if (k_.has_value()) {
+        TORCH_CHECK(v_.has_value(), "If key is supplied, value must also be passed in");
+        TORCH_CHECK(seqlens_k_.has_value(), "If key is supplied, seqlens_k must also be passed in");
+        auto k = k_.value();
+        auto v = v_.value();
+        TORCH_CHECK(k.dtype() == q_dtype && v.dtype() == q_dtype);
+        CHECK_DEVICE(k); CHECK_DEVICE(v);
+        TORCH_CHECK(k.stride(-1) == 1 && v.stride(-1) == 1);
+        seqlen_new = k.size(1);
+        k_padded = maybe_pad(ensure_contiguous(k));
+        v_padded = maybe_pad(ensure_contiguous(v));
+    }
+
+    at::Tensor seqlens_k;
+    if (seqlens_k_.has_value()) {
+        seqlens_k = seqlens_k_.value();
+        TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32");
+        CHECK_DEVICE(seqlens_k);
+    }
+
+    at::Tensor rotary_cos, rotary_sin;
+    int rotary_dim = 0;
+    const bool has_rotary = rotary_cos_.has_value();
+    if (has_rotary) {
+        TORCH_CHECK(k_.has_value(), "If rotary cos/sin are provided, new key/value must also be provided");
+        TORCH_CHECK(rotary_sin_.has_value(), "If rotary cos is provided, rotary sin must also be provided");
+        rotary_cos = ensure_contiguous(rotary_cos_.value());
+        rotary_sin = ensure_contiguous(rotary_sin_.value());
+        CHECK_DEVICE(rotary_cos); CHECK_DEVICE(rotary_sin);
+        rotary_dim = rotary_cos.size(1) * 2;
+        TORCH_CHECK(rotary_dim <= head_size_og,
"rotary_dim must be <= headdim"); + TORCH_CHECK(rotary_dim % 16 == 0, "Only rotary dimensions divisible by 16 are currently supported"); + TORCH_CHECK(rotary_cos.scalar_type() == q_dtype && rotary_sin.scalar_type() == q_dtype); + } + + at::Tensor cache_batch_idx; + if (cache_batch_idx_.has_value()) { + cache_batch_idx = cache_batch_idx_.value(); + CHECK_DEVICE(cache_batch_idx); + TORCH_CHECK(cache_batch_idx.scalar_type() == torch::kInt32); + } + + at::Tensor leftpad_k; + if (leftpad_k_.has_value()) { + TORCH_CHECK(!paged_KV, "Paged KV and leftpad_k are not supported together"); + leftpad_k = leftpad_k_.value(); + CHECK_DEVICE(leftpad_k); + TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32"); + } + + bool fuse_knew = k_.has_value() && seqlen_new > 0; + + // Dispatch to kernel. Paged caches are now passed natively (block_table + // routed straight through to the kernel, no host gather). + auto queue = c10::xpu::getCurrentXPUStream(device_idx).queue(); + const bool is_local = (window_size_left >= 0); + + std::optional cache_batch_idx_opt; + if (cache_batch_idx_.has_value()) { + cache_batch_idx_opt = cache_batch_idx; + } + + std::optional leftpad_k_opt; + if (leftpad_k_.has_value()) { + leftpad_k_opt = leftpad_k; + } + + // For paths where new KV is appended in-kernel, pass knew/vnew through. + std::optional knew_opt, vnew_opt; + if (fuse_knew) { + knew_opt = k_padded; + vnew_opt = v_padded; + } + + std::optional block_table_opt; + if (paged_KV) { + block_table_opt = block_table; + } + + std::optional rotary_cos_opt, rotary_sin_opt; + if (fuse_knew && has_rotary) { + rotary_cos_opt = rotary_cos; + rotary_sin_opt = rotary_sin; + } + + cutlass_fmha_fwd_kvcache_impl( + queue, + q_padded, kcache_padded, vcache_padded, + out, softmax_lse, + seqlens_k, cache_batch_idx_opt, leftpad_k_opt, + knew_opt, vnew_opt, + block_table_opt, rotary_cos_opt, rotary_sin_opt, + fuse_knew ? rotary_dim : 0, is_rotary_interleaved, seqlen_k, + softmax_scale, window_size_left, window_size_right, + is_causal, is_local); + + // Update seqlens_k after kernel completes (for fused scatter path) + if (fuse_knew) { + seqlens_k = seqlens_k + seqlen_new; + } + + if (needs_padding) { + out = out.index({torch::indexing::Slice(), torch::indexing::Slice(), + torch::indexing::Slice(), torch::indexing::Slice(0, head_size_og)}) + .contiguous(); + if (out_.has_value()) { out_.value().copy_(out); } + if (fuse_knew) { + // The fused kernel updates the padded cache buffer; publish valid dims back to the user cache. 
+            kcache.copy_(kcache_padded.index({torch::indexing::Slice(), torch::indexing::Slice(),
+                                              torch::indexing::Slice(), torch::indexing::Slice(0, head_size_og)}));
+            vcache.copy_(vcache_padded.index({torch::indexing::Slice(), torch::indexing::Slice(),
+                                              torch::indexing::Slice(), torch::indexing::Slice(0, head_size_og)}));
+        }
+    }
+
+    return {out, softmax_lse};
+}
+
 } // namespace FLASH_NAMESPACE
 
 // std::tuple
@@ -633,4 +855,62 @@ mha_bwd(const torch::Tensor &dout,
       gen_,
       rng_opt
   );
+}
+
+std::vector<at::Tensor>
+mha_fwd_kvcache(
+    const torch::Tensor &q,
+    const torch::Tensor &kcache,
+    const torch::Tensor &vcache,
+    const c10::optional<torch::Tensor> &k_,
+    const c10::optional<torch::Tensor> &v_,
+    const c10::optional<torch::Tensor> &seqlens_k_,
+    const c10::optional<torch::Tensor> &rotary_cos_,
+    const c10::optional<torch::Tensor> &rotary_sin_,
+    const c10::optional<torch::Tensor> &cache_batch_idx_,
+    const c10::optional<torch::Tensor> &leftpad_k_,
+    const c10::optional<torch::Tensor> &block_table_,
+    const c10::optional<torch::Tensor> &alibi_slopes_,
+    const c10::optional<torch::Tensor> &out_,
+    const double softmax_scale,
+    bool is_causal,
+    const int64_t window_size_left,
+    const int64_t window_size_right,
+    const double softcap,
+    bool is_rotary_interleaved,
+    const int64_t num_splits) {
+    // Convert c10::optional -> std::optional for the internal API
+    auto to_std_opt = [](const c10::optional<torch::Tensor>& opt) -> std::optional<at::Tensor> {
+        return opt.has_value() ? std::optional<at::Tensor>(opt.value()) : std::nullopt;
+    };
+    auto to_std_opt_const = [](const c10::optional<torch::Tensor>& opt) -> std::optional<const at::Tensor> {
+        return opt.has_value() ? std::optional<const at::Tensor>(opt.value()) : std::nullopt;
+    };
+
+    at::Tensor q_mut = q;
+    auto k_opt = to_std_opt_const(k_);
+    auto v_opt = to_std_opt_const(v_);
+    auto seqlens_opt = to_std_opt_const(seqlens_k_);
+    auto rotary_cos_opt = to_std_opt_const(rotary_cos_);
+    auto rotary_sin_opt = to_std_opt_const(rotary_sin_);
+    auto cache_batch_idx_opt = to_std_opt_const(cache_batch_idx_);
+    auto leftpad_k_opt = to_std_opt_const(leftpad_k_);
+    auto block_table_opt = to_std_opt(block_table_);
+    auto alibi_opt = to_std_opt(alibi_slopes_);
+    auto out_opt = to_std_opt(out_);
+
+    return FLASH_NAMESPACE::mha_fwd_kvcache(
+        q_mut, kcache, vcache,
+        k_opt, v_opt, seqlens_opt,
+        rotary_cos_opt, rotary_sin_opt,
+        cache_batch_idx_opt, leftpad_k_opt,
+        block_table_opt, alibi_opt, out_opt,
+        static_cast<float>(softmax_scale),
+        is_causal,
+        static_cast<int>(window_size_left),
+        static_cast<int>(window_size_right),
+        static_cast<float>(softcap),
+        is_rotary_interleaved,
+        static_cast<int>(num_splits)
+    );
 }
\ No newline at end of file
diff --git a/flash-attn2/flash_attn_xpu/src/collective/fmha_fwd_common.hpp b/flash-attn2/flash_attn_xpu/src/collective/fmha_fwd_common.hpp
index 7e8c0ea5..b992343a 100644
--- a/flash-attn2/flash_attn_xpu/src/collective/fmha_fwd_common.hpp
+++ b/flash-attn2/flash_attn_xpu/src/collective/fmha_fwd_common.hpp
@@ -22,6 +22,48 @@ namespace cutlass::fmha::collective {
 
 using namespace cute;
 
+template <typename Element, typename RotaryElement>
+CUTLASS_DEVICE Element apply_rotary_scalar(
+    Element x,
+    Element x_pair,
+    const RotaryElement* cos,
+    const RotaryElement* sin,
+    int position,
+    int dim,
+    int rotary_dim,
+    bool interleaved) {
+  if (rotary_dim == 0 || dim >= rotary_dim) {
+    return x;
+  }
+
+  int half_rotary = rotary_dim / 2;
+  int cos_sin_idx = interleaved ? dim / 2
+                                : (dim < half_rotary ? dim : dim - half_rotary);
+  bool is_second = interleaved ? (dim % 2) : (dim >= half_rotary);
+
+  float x_f = static_cast<float>(x);
+  float x_pair_f = static_cast<float>(x_pair);
+  float c = static_cast<float>(cos[position * half_rotary + cos_sin_idx]);
+  float s = static_cast<float>(sin[position * half_rotary + cos_sin_idx]);
+  float rotated = is_second ?
x_pair_f * s + x_f * c + : x_f * c - x_pair_f * s; + return static_cast(rotated); +} + +CUTLASS_DEVICE int rotary_pair_dim( + int dim, + int rotary_dim, + bool interleaved) { + if (dim >= rotary_dim) { + return dim; + } + if (interleaved) { + return dim ^ 1; + } + int half_rotary = rotary_dim / 2; + return dim < half_rotary ? dim + half_rotary : dim - half_rotary; +} + ///////////////////////////////////////////////////////////////////////////////////////////////// // // FMHAFwdMainloopTraits: common type aliases derived from TiledMMA / VTiles. diff --git a/flash-attn2/flash_attn_xpu/src/collective/fmha_fwd_mainloop_xe2.hpp b/flash-attn2/flash_attn_xpu/src/collective/fmha_fwd_mainloop_xe2.hpp index 4cff3697..8f7a5b1e 100644 --- a/flash-attn2/flash_attn_xpu/src/collective/fmha_fwd_mainloop_xe2.hpp +++ b/flash-attn2/flash_attn_xpu/src/collective/fmha_fwd_mainloop_xe2.hpp @@ -47,7 +47,8 @@ template < class TensorV_, class TiledCopyQ_ = void, class TiledCopyK_ = void, - class TiledCopyV_ = void> + class TiledCopyV_ = void, + bool HasRotary_ = false> struct FMHAFwdMainloopXe2 { static_assert( cutlass::detail::dependent_false, @@ -70,7 +71,8 @@ template < class TensorV_, class TiledCopyQ_, class TiledCopyK_, - class TiledCopyV_> + class TiledCopyV_, + bool HasRotary_> struct FMHAFwdMainloopXe2< Xe2, CausalMask_, @@ -85,7 +87,8 @@ struct FMHAFwdMainloopXe2< TensorV_, TiledCopyQ_, TiledCopyK_, - TiledCopyV_> { + TiledCopyV_, + HasRotary_> { // Pull in common type aliases from the shared traits. using Traits = FMHAFwdMainloopTraits< @@ -123,6 +126,7 @@ struct FMHAFwdMainloopXe2< static constexpr bool LocalMask = LocalMask_; static constexpr bool HasDropout = HasDropout_; static constexpr bool PagedKV = PagedKV_; + static constexpr bool HasRotary = HasRotary_; // User-facing arguments struct Arguments { @@ -142,6 +146,10 @@ struct FMHAFwdMainloopXe2< int page_size = 0; int max_pages_per_seq = 0; int total_seqlen_kv = 0; + const typename TensorQ::element_type* rotary_cos = nullptr; + const typename TensorQ::element_type* rotary_sin = nullptr; + int rotary_dim = 0; + bool is_rotary_interleaved = true; }; struct LocalMaskFields { @@ -165,6 +173,14 @@ struct FMHAFwdMainloopXe2< }; struct EmptyPaged {}; + struct RotaryFields { + const typename TensorQ::element_type* rotary_cos = nullptr; + const typename TensorQ::element_type* rotary_sin = nullptr; + int rotary_dim = 0; + bool is_rotary_interleaved = true; + }; + struct EmptyRotary {}; + // Kernel-facing parameters struct Params { ElementS scale; @@ -174,6 +190,8 @@ struct FMHAFwdMainloopXe2< dropout_fields; [[no_unique_address]] conditional_t paged; + [[no_unique_address]] conditional_t + rotary; }; // SLM data @@ -209,6 +227,10 @@ struct FMHAFwdMainloopXe2< p.paged = {args.ptr_page_table, args.page_size, args.max_pages_per_seq, args.total_seqlen_kv}; } + if constexpr (HasRotary) { + p.rotary = {args.rotary_cos, args.rotary_sin, + args.rotary_dim, args.is_rotary_interleaved}; + } return p; } @@ -236,7 +258,9 @@ struct FMHAFwdMainloopXe2< int& tile_row_idx, const int& rows_of_maxima, int head_q, - int num_heads) { + int num_heads, + int q_offset_sg, + int rotary_base) { using namespace sycl::ext::oneapi::this_work_item; auto tile_shape_v = @@ -288,10 +312,13 @@ struct FMHAFwdMainloopXe2< auto prefetch_q = make_block_2d_prefetch(copy_q); auto prefetch_k = make_block_2d_prefetch(copy_k); auto prefetch_v = make_block_2d_prefetch(copy_v); + auto prefetch_v_paged = + make_block_2d_prefetch(tile_shape_v, V_2D); auto pQgQ = 
prefetch_q.get_slice(thr_id).partition_S(gQ); auto pKgK = prefetch_k.get_slice(thr_id).partition_S(gK); auto pVgV = prefetch_v.get_slice(thr_id).partition_S(gV_split); + auto pVgV_paged = prefetch_v_paged.get_slice(thr_id).partition_S(gV); // PagedKV: translate logical K index to physical page-tile index. int tiles_per_page = 0; @@ -311,20 +338,27 @@ struct FMHAFwdMainloopXe2< for (int D = 0; D < size<3>(pQgQ); D++) { prefetch(prefetch_q, pQgQ(_, _, _, D)); } - int prefetch_k_stages = (total_blk < Stages ? total_blk : Stages); - for (int D = 0; D < size<4>(pKgK); D++) { + if constexpr (PagedKV) { CUTLASS_PRAGMA_UNROLL - for (int K = blk_k0; K < blk_k0 + prefetch_k_stages; K++) { - int pk; - if constexpr (PagedKV) { - int ploc = K * get<1>(TileShapeQK{}) / params.paged.page_size; - pk = params.paged.ptr_page_table[b_offset + ploc] * - tiles_per_page + - K % tiles_per_page; - } else { - pk = K; + for (int D = 0; D < size<4>(pKgK); D++) { + prefetch(prefetch_k, pKgK(_, _, _, next_page_idx, D)); + } + } else { + int prefetch_k_stages = (total_blk < Stages ? total_blk : Stages); + for (int D = 0; D < size<4>(pKgK); D++) { + CUTLASS_PRAGMA_UNROLL + for (int K = blk_k0; K < blk_k0 + prefetch_k_stages; K++) { + int pk; + if constexpr (PagedKV) { + int ploc = K * get<1>(TileShapeQK{}) / params.paged.page_size; + pk = params.paged.ptr_page_table[b_offset + ploc] * + tiles_per_page + + K % tiles_per_page; + } else { + pk = K; + } + prefetch(prefetch_k, pKgK(_, _, _, pk, D)); } - prefetch(prefetch_k, pKgK(_, _, _, pk, D)); } } clear(tArA); @@ -361,9 +395,11 @@ struct FMHAFwdMainloopXe2< auto tVgV_cache = PagedKV ? tVgV(_, _, _, _, page_idx) : tVgV(_, _, _, _, K); - // Non-paged: prefetch V before GEMM1 to overlap with computation. - // Paged: prefetch V after GEMM1 to avoid BMG hardware hang. - if constexpr (!PagedKV) { + // Paged path uses the old whole-V-tile prefetch pattern; split-V + // prefetch-after-GEMM can hang on BMG for some paged cases. + if constexpr (PagedKV) { + prefetch(prefetch_v_paged, pVgV_paged(_, _, _, page_idx)); + } else if constexpr (!PagedKV) { CUTLASS_PRAGMA_UNROLL for (int VV = 0; VV < VTiles; VV++) { prefetch(prefetch_v, pVgV(_, _, _, VV, K)); @@ -375,19 +411,44 @@ struct FMHAFwdMainloopXe2< CUTLASS_PRAGMA_UNROLL for (int D = 0; D < size<4>(tKgK); D++) { copy(copy_q, tQgQ(_, _, _, D), tQrQ); + if constexpr (HasRotary) { + if (params.rotary.rotary_dim > 0 && + params.rotary.rotary_cos != nullptr && + params.rotary.rotary_sin != nullptr) { + auto tQrQ_coords = tQrQ.tv_layout(); + int lane_id = static_cast(get_sub_group().get_local_linear_id()); + int q_tile_base = get<0>(blk_qv) * get<0>(TileShapeQK{}) + q_offset_sg; + int dim_tile_base = D * get<2>(TileShapeQK{}); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < tQrQ.size(); ++i) { + auto value_coord = idx2crd( + i, make_shape( + get<1>(shape(tQrQ_coords)), + get<2>(shape(tQrQ_coords)))); + auto coord = tQrQ_coords( + make_coord(lane_id, get<0>(value_coord), get<1>(value_coord))); + int row = q_tile_base + get<0>(coord); + int dim = dim_tile_base + get<1>(coord); + if (row < seq_len_qo && dim < params.rotary.rotary_dim) { + int pair_dim = rotary_pair_dim( + dim, params.rotary.rotary_dim, + params.rotary.is_rotary_interleaved); + int position = rotary_base + ((CausalMask || LocalMask) ? 
row : 0); + tQrQ(i) = apply_rotary_scalar( + tQrQ(i), Q_2D(row, pair_dim), params.rotary.rotary_cos, + params.rotary.rotary_sin, position, dim, + params.rotary.rotary_dim, + params.rotary.is_rotary_interleaved); + } + } + } + } copy(copy_k, tKgK_cache(_, _, _, D), tKrK); reorder(tQrQ, tSrQ); reorder(tKrK, tSrK); cute::gemm(mma_qk, tSrQ, tSrK, tSrS); } - if constexpr (PagedKV) { - CUTLASS_PRAGMA_UNROLL - for (int VV = 0; VV < VTiles; VV++) { - prefetch(prefetch_v, pVgV(_, _, _, VV, page_idx)); - } - } - auto cS_thread = thr_mma_qk.partition_C(gP_all(_, _, K)); if (check_remainder_k && K == blk_k1 - 1) { @@ -522,28 +583,35 @@ struct FMHAFwdMainloopXe2< cute::gemm(mma_pv, tArP, tArV, tArA(_, _, _, VV)); } - int K_next = K + Stages; - if (K_next < blk_k1) { - if constexpr (PagedKV) { - int next_page_local_idx = - K_next * get<1>(TileShapeQK{}) / params.paged.page_size; - int pk_next; - if (next_page_local_idx < params.paged.max_pages_per_seq) { - pk_next = - params.paged.ptr_page_table[b_offset + next_page_local_idx] * - tiles_per_page + - K_next % tiles_per_page; + if constexpr (PagedKV) { + CUTLASS_PRAGMA_UNROLL + for (int D = 0; D < size<4>(pKgK); D++) { + prefetch(prefetch_k, pKgK(_, _, _, next_page_idx, D)); + } + } else { + int K_next = K + Stages; + if (K_next < blk_k1) { + if constexpr (PagedKV) { + int next_page_local_idx = + K_next * get<1>(TileShapeQK{}) / params.paged.page_size; + int pk_next; + if (next_page_local_idx < params.paged.max_pages_per_seq) { + pk_next = + params.paged.ptr_page_table[b_offset + next_page_local_idx] * + tiles_per_page + + K_next % tiles_per_page; + } else { + pk_next = params.paged.max_pages_per_seq * tiles_per_page - 1; + } + CUTLASS_PRAGMA_UNROLL + for (int D = 0; D < size<4>(pKgK); D++) { + prefetch(prefetch_k, pKgK(_, _, _, pk_next, D)); + } } else { - pk_next = params.paged.max_pages_per_seq * tiles_per_page - 1; - } - CUTLASS_PRAGMA_UNROLL - for (int D = 0; D < size<4>(pKgK); D++) { - prefetch(prefetch_k, pKgK(_, _, _, pk_next, D)); - } - } else { - CUTLASS_PRAGMA_UNROLL - for (int D = 0; D < size<4>(pKgK); D++) { - prefetch(prefetch_k, pKgK(_, _, _, K_next, D)); + CUTLASS_PRAGMA_UNROLL + for (int D = 0; D < size<4>(pKgK); D++) { + prefetch(prefetch_k, pKgK(_, _, _, K_next, D)); + } } } } diff --git a/flash-attn2/flash_attn_xpu/src/create_instantiation_files.sh b/flash-attn2/flash_attn_xpu/src/create_instantiation_files.sh index 90036e3b..aebd3c85 100755 --- a/flash-attn2/flash_attn_xpu/src/create_instantiation_files.sh +++ b/flash-attn2/flash_attn_xpu/src/create_instantiation_files.sh @@ -88,8 +88,68 @@ ENDFILE done done +echo "" +echo "Creating kvcache-paged instantiation files (split by dtype)..." +for hdim in "${HDIMS[@]}"; do + for dtype in fp16 bf16; do + cat > flash_fwd_hdim${hdim}_kvcache_paged_${dtype}.cpp << ENDFILE +#include "fmha_fwd_impl.hpp" + +// Non-varlen + paged: IsVarLen=0, IsPaged=1, dtype=${dtype} +// Used by mha_fwd_kvcache when block_table is provided. + +// Prefill paged +template void policy_dispatch_${dtype}< + prefill_policy_head${hdim}, + PipelineStages_Prefill, + 0, 1>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +// Decode paged (smaller K-tile to fit page boundaries) +template void policy_dispatch_${dtype}< + decode_paged_policy_head${hdim}, + PipelineStages_Decode, + 0, 1>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +// Rotary kvcache variants keep rotary code out of non-rotary kernels. 
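+// Per (hdim, dtype) this TU thus carries six instantiations: the paged
+// non-rotary pair above plus rotary variants of both the non-paged and paged
+// kernels below; the non-paged non-rotary pair presumably remains in the
+// existing *_fix_* TUs.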
+template void policy_dispatch_${dtype}< + prefill_policy_head${hdim}, + PipelineStages_Prefill, + 0, 0, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_${dtype}< + decode_policy_head${hdim}, + PipelineStages_Decode, + 0, 0, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_${dtype}< + prefill_policy_head${hdim}, + PipelineStages_Prefill, + 0, 1, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_${dtype}< + decode_paged_policy_head${hdim}, + PipelineStages_Decode, + 0, 1, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); +ENDFILE + echo " Created flash_fwd_hdim${hdim}_kvcache_paged_${dtype}.cpp" + done +done + echo "" echo "✓ All instantiation files created successfully!" echo " - $((${#HDIMS[@]} * 2)) varlen files (split by dtype)" echo " - $((${#HDIMS[@]} * 2)) fixed files (split by dtype)" -echo " Total: $((${#HDIMS[@]} * 4)) files" +echo " - $((${#HDIMS[@]} * 2)) kvcache_paged files (split by dtype)" +echo " Total: $((${#HDIMS[@]} * 6)) files" diff --git a/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim128_kvcache_paged_bf16.cpp b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim128_kvcache_paged_bf16.cpp new file mode 100644 index 00000000..d19f10b4 --- /dev/null +++ b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim128_kvcache_paged_bf16.cpp @@ -0,0 +1,49 @@ +#include "fmha_fwd_impl.hpp" + +// Non-varlen + paged: IsVarLen=0, IsPaged=1, dtype=bf16 +// Used by mha_fwd_kvcache when block_table is provided. + +// Prefill paged +template void policy_dispatch_bf16< + prefill_policy_head128, + PipelineStages_Prefill, + 0, 1>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +// Decode paged (smaller K-tile to fit page boundaries) +template void policy_dispatch_bf16< + decode_paged_policy_head128, + PipelineStages_Decode, + 0, 1>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +// Rotary kvcache variants keep rotary code out of non-rotary kernels. +template void policy_dispatch_bf16< + prefill_policy_head128, + PipelineStages_Prefill, + 0, 0, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_bf16< + decode_policy_head128, + PipelineStages_Decode, + 0, 0, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_bf16< + prefill_policy_head128, + PipelineStages_Prefill, + 0, 1, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_bf16< + decode_paged_policy_head128, + PipelineStages_Decode, + 0, 1, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); diff --git a/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim128_kvcache_paged_fp16.cpp b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim128_kvcache_paged_fp16.cpp new file mode 100644 index 00000000..3676037a --- /dev/null +++ b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim128_kvcache_paged_fp16.cpp @@ -0,0 +1,49 @@ +#include "fmha_fwd_impl.hpp" + +// Non-varlen + paged: IsVarLen=0, IsPaged=1, dtype=fp16 +// Used by mha_fwd_kvcache when block_table is provided. 
+ +// Prefill paged +template void policy_dispatch_fp16< + prefill_policy_head128, + PipelineStages_Prefill, + 0, 1>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +// Decode paged (smaller K-tile to fit page boundaries) +template void policy_dispatch_fp16< + decode_paged_policy_head128, + PipelineStages_Decode, + 0, 1>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +// Rotary kvcache variants keep rotary code out of non-rotary kernels. +template void policy_dispatch_fp16< + prefill_policy_head128, + PipelineStages_Prefill, + 0, 0, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_fp16< + decode_policy_head128, + PipelineStages_Decode, + 0, 0, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_fp16< + prefill_policy_head128, + PipelineStages_Prefill, + 0, 1, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_fp16< + decode_paged_policy_head128, + PipelineStages_Decode, + 0, 1, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); diff --git a/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim160_kvcache_paged_bf16.cpp b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim160_kvcache_paged_bf16.cpp new file mode 100644 index 00000000..63fd6b19 --- /dev/null +++ b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim160_kvcache_paged_bf16.cpp @@ -0,0 +1,49 @@ +#include "fmha_fwd_impl.hpp" + +// Non-varlen + paged: IsVarLen=0, IsPaged=1, dtype=bf16 +// Used by mha_fwd_kvcache when block_table is provided. + +// Prefill paged +template void policy_dispatch_bf16< + prefill_policy_head160, + PipelineStages_Prefill, + 0, 1>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +// Decode paged (smaller K-tile to fit page boundaries) +template void policy_dispatch_bf16< + decode_paged_policy_head160, + PipelineStages_Decode, + 0, 1>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +// Rotary kvcache variants keep rotary code out of non-rotary kernels. +template void policy_dispatch_bf16< + prefill_policy_head160, + PipelineStages_Prefill, + 0, 0, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_bf16< + decode_policy_head160, + PipelineStages_Decode, + 0, 0, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_bf16< + prefill_policy_head160, + PipelineStages_Prefill, + 0, 1, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_bf16< + decode_paged_policy_head160, + PipelineStages_Decode, + 0, 1, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); diff --git a/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim160_kvcache_paged_fp16.cpp b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim160_kvcache_paged_fp16.cpp new file mode 100644 index 00000000..cb8cda7d --- /dev/null +++ b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim160_kvcache_paged_fp16.cpp @@ -0,0 +1,49 @@ +#include "fmha_fwd_impl.hpp" + +// Non-varlen + paged: IsVarLen=0, IsPaged=1, dtype=fp16 +// Used by mha_fwd_kvcache when block_table is provided. 
+ +// Prefill paged +template void policy_dispatch_fp16< + prefill_policy_head160, + PipelineStages_Prefill, + 0, 1>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +// Decode paged (smaller K-tile to fit page boundaries) +template void policy_dispatch_fp16< + decode_paged_policy_head160, + PipelineStages_Decode, + 0, 1>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +// Rotary kvcache variants keep rotary code out of non-rotary kernels. +template void policy_dispatch_fp16< + prefill_policy_head160, + PipelineStages_Prefill, + 0, 0, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_fp16< + decode_policy_head160, + PipelineStages_Decode, + 0, 0, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_fp16< + prefill_policy_head160, + PipelineStages_Prefill, + 0, 1, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_fp16< + decode_paged_policy_head160, + PipelineStages_Decode, + 0, 1, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); diff --git a/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim192_kvcache_paged_bf16.cpp b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim192_kvcache_paged_bf16.cpp new file mode 100644 index 00000000..14346625 --- /dev/null +++ b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim192_kvcache_paged_bf16.cpp @@ -0,0 +1,49 @@ +#include "fmha_fwd_impl.hpp" + +// Non-varlen + paged: IsVarLen=0, IsPaged=1, dtype=bf16 +// Used by mha_fwd_kvcache when block_table is provided. + +// Prefill paged +template void policy_dispatch_bf16< + prefill_policy_head192, + PipelineStages_Prefill, + 0, 1>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +// Decode paged (smaller K-tile to fit page boundaries) +template void policy_dispatch_bf16< + decode_paged_policy_head192, + PipelineStages_Decode, + 0, 1>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +// Rotary kvcache variants keep rotary code out of non-rotary kernels. +template void policy_dispatch_bf16< + prefill_policy_head192, + PipelineStages_Prefill, + 0, 0, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_bf16< + decode_policy_head192, + PipelineStages_Decode, + 0, 0, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_bf16< + prefill_policy_head192, + PipelineStages_Prefill, + 0, 1, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_bf16< + decode_paged_policy_head192, + PipelineStages_Decode, + 0, 1, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); diff --git a/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim192_kvcache_paged_fp16.cpp b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim192_kvcache_paged_fp16.cpp new file mode 100644 index 00000000..a34c90cd --- /dev/null +++ b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim192_kvcache_paged_fp16.cpp @@ -0,0 +1,49 @@ +#include "fmha_fwd_impl.hpp" + +// Non-varlen + paged: IsVarLen=0, IsPaged=1, dtype=fp16 +// Used by mha_fwd_kvcache when block_table is provided. 
+ +// Prefill paged +template void policy_dispatch_fp16< + prefill_policy_head192, + PipelineStages_Prefill, + 0, 1>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +// Decode paged (smaller K-tile to fit page boundaries) +template void policy_dispatch_fp16< + decode_paged_policy_head192, + PipelineStages_Decode, + 0, 1>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +// Rotary kvcache variants keep rotary code out of non-rotary kernels. +template void policy_dispatch_fp16< + prefill_policy_head192, + PipelineStages_Prefill, + 0, 0, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_fp16< + decode_policy_head192, + PipelineStages_Decode, + 0, 0, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_fp16< + prefill_policy_head192, + PipelineStages_Prefill, + 0, 1, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_fp16< + decode_paged_policy_head192, + PipelineStages_Decode, + 0, 1, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); diff --git a/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim256_kvcache_paged_bf16.cpp b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim256_kvcache_paged_bf16.cpp new file mode 100644 index 00000000..4d9f0dde --- /dev/null +++ b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim256_kvcache_paged_bf16.cpp @@ -0,0 +1,49 @@ +#include "fmha_fwd_impl.hpp" + +// Non-varlen + paged: IsVarLen=0, IsPaged=1, dtype=bf16 +// Used by mha_fwd_kvcache when block_table is provided. + +// Prefill paged +template void policy_dispatch_bf16< + prefill_policy_head256, + PipelineStages_Prefill, + 0, 1>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +// Decode paged (smaller K-tile to fit page boundaries) +template void policy_dispatch_bf16< + decode_paged_policy_head256, + PipelineStages_Decode, + 0, 1>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +// Rotary kvcache variants keep rotary code out of non-rotary kernels. +template void policy_dispatch_bf16< + prefill_policy_head256, + PipelineStages_Prefill, + 0, 0, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_bf16< + decode_policy_head256, + PipelineStages_Decode, + 0, 0, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_bf16< + prefill_policy_head256, + PipelineStages_Prefill, + 0, 1, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_bf16< + decode_paged_policy_head256, + PipelineStages_Decode, + 0, 1, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); diff --git a/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim256_kvcache_paged_fp16.cpp b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim256_kvcache_paged_fp16.cpp new file mode 100644 index 00000000..3e3b8303 --- /dev/null +++ b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim256_kvcache_paged_fp16.cpp @@ -0,0 +1,49 @@ +#include "fmha_fwd_impl.hpp" + +// Non-varlen + paged: IsVarLen=0, IsPaged=1, dtype=fp16 +// Used by mha_fwd_kvcache when block_table is provided. 
+ +// Prefill paged +template void policy_dispatch_fp16< + prefill_policy_head256, + PipelineStages_Prefill, + 0, 1>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +// Decode paged (smaller K-tile to fit page boundaries) +template void policy_dispatch_fp16< + decode_paged_policy_head256, + PipelineStages_Decode, + 0, 1>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +// Rotary kvcache variants keep rotary code out of non-rotary kernels. +template void policy_dispatch_fp16< + prefill_policy_head256, + PipelineStages_Prefill, + 0, 0, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_fp16< + decode_policy_head256, + PipelineStages_Decode, + 0, 0, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_fp16< + prefill_policy_head256, + PipelineStages_Prefill, + 0, 1, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_fp16< + decode_paged_policy_head256, + PipelineStages_Decode, + 0, 1, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); diff --git a/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim32_kvcache_paged_bf16.cpp b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim32_kvcache_paged_bf16.cpp new file mode 100644 index 00000000..f5d45308 --- /dev/null +++ b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim32_kvcache_paged_bf16.cpp @@ -0,0 +1,49 @@ +#include "fmha_fwd_impl.hpp" + +// Non-varlen + paged: IsVarLen=0, IsPaged=1, dtype=bf16 +// Used by mha_fwd_kvcache when block_table is provided. + +// Prefill paged +template void policy_dispatch_bf16< + prefill_policy_head32, + PipelineStages_Prefill, + 0, 1>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +// Decode paged (smaller K-tile to fit page boundaries) +template void policy_dispatch_bf16< + decode_paged_policy_head32, + PipelineStages_Decode, + 0, 1>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +// Rotary kvcache variants keep rotary code out of non-rotary kernels. +template void policy_dispatch_bf16< + prefill_policy_head32, + PipelineStages_Prefill, + 0, 0, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_bf16< + decode_policy_head32, + PipelineStages_Decode, + 0, 0, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_bf16< + prefill_policy_head32, + PipelineStages_Prefill, + 0, 1, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_bf16< + decode_paged_policy_head32, + PipelineStages_Decode, + 0, 1, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); diff --git a/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim32_kvcache_paged_fp16.cpp b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim32_kvcache_paged_fp16.cpp new file mode 100644 index 00000000..3a09b49f --- /dev/null +++ b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim32_kvcache_paged_fp16.cpp @@ -0,0 +1,49 @@ +#include "fmha_fwd_impl.hpp" + +// Non-varlen + paged: IsVarLen=0, IsPaged=1, dtype=fp16 +// Used by mha_fwd_kvcache when block_table is provided. 
+ +// Prefill paged +template void policy_dispatch_fp16< + prefill_policy_head32, + PipelineStages_Prefill, + 0, 1>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +// Decode paged (smaller K-tile to fit page boundaries) +template void policy_dispatch_fp16< + decode_paged_policy_head32, + PipelineStages_Decode, + 0, 1>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +// Rotary kvcache variants keep rotary code out of non-rotary kernels. +template void policy_dispatch_fp16< + prefill_policy_head32, + PipelineStages_Prefill, + 0, 0, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_fp16< + decode_policy_head32, + PipelineStages_Decode, + 0, 0, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_fp16< + prefill_policy_head32, + PipelineStages_Prefill, + 0, 1, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_fp16< + decode_paged_policy_head32, + PipelineStages_Decode, + 0, 1, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); diff --git a/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim512_kvcache_paged_bf16.cpp b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim512_kvcache_paged_bf16.cpp new file mode 100644 index 00000000..2a25b6e9 --- /dev/null +++ b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim512_kvcache_paged_bf16.cpp @@ -0,0 +1,49 @@ +#include "fmha_fwd_impl.hpp" + +// Non-varlen + paged: IsVarLen=0, IsPaged=1, dtype=bf16 +// Used by mha_fwd_kvcache when block_table is provided. + +// Prefill paged +template void policy_dispatch_bf16< + prefill_policy_head512, + PipelineStages_Prefill, + 0, 1>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +// Decode paged (smaller K-tile to fit page boundaries) +template void policy_dispatch_bf16< + decode_paged_policy_head512, + PipelineStages_Decode, + 0, 1>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +// Rotary kvcache variants keep rotary code out of non-rotary kernels. +template void policy_dispatch_bf16< + prefill_policy_head512, + PipelineStages_Prefill, + 0, 0, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_bf16< + decode_policy_head512, + PipelineStages_Decode, + 0, 0, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_bf16< + prefill_policy_head512, + PipelineStages_Prefill, + 0, 1, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_bf16< + decode_paged_policy_head512, + PipelineStages_Decode, + 0, 1, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); diff --git a/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim512_kvcache_paged_fp16.cpp b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim512_kvcache_paged_fp16.cpp new file mode 100644 index 00000000..31b7d308 --- /dev/null +++ b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim512_kvcache_paged_fp16.cpp @@ -0,0 +1,49 @@ +#include "fmha_fwd_impl.hpp" + +// Non-varlen + paged: IsVarLen=0, IsPaged=1, dtype=fp16 +// Used by mha_fwd_kvcache when block_table is provided. 
+ +// Prefill paged +template void policy_dispatch_fp16< + prefill_policy_head512, + PipelineStages_Prefill, + 0, 1>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +// Decode paged (smaller K-tile to fit page boundaries) +template void policy_dispatch_fp16< + decode_paged_policy_head512, + PipelineStages_Decode, + 0, 1>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +// Rotary kvcache variants keep rotary code out of non-rotary kernels. +template void policy_dispatch_fp16< + prefill_policy_head512, + PipelineStages_Prefill, + 0, 0, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_fp16< + decode_policy_head512, + PipelineStages_Decode, + 0, 0, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_fp16< + prefill_policy_head512, + PipelineStages_Prefill, + 0, 1, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_fp16< + decode_paged_policy_head512, + PipelineStages_Decode, + 0, 1, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); diff --git a/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim64_kvcache_paged_bf16.cpp b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim64_kvcache_paged_bf16.cpp new file mode 100644 index 00000000..a55d84e9 --- /dev/null +++ b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim64_kvcache_paged_bf16.cpp @@ -0,0 +1,49 @@ +#include "fmha_fwd_impl.hpp" + +// Non-varlen + paged: IsVarLen=0, IsPaged=1, dtype=bf16 +// Used by mha_fwd_kvcache when block_table is provided. + +// Prefill paged +template void policy_dispatch_bf16< + prefill_policy_head64, + PipelineStages_Prefill, + 0, 1>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +// Decode paged (smaller K-tile to fit page boundaries) +template void policy_dispatch_bf16< + decode_paged_policy_head64, + PipelineStages_Decode, + 0, 1>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +// Rotary kvcache variants keep rotary code out of non-rotary kernels. +template void policy_dispatch_bf16< + prefill_policy_head64, + PipelineStages_Prefill, + 0, 0, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_bf16< + decode_policy_head64, + PipelineStages_Decode, + 0, 0, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_bf16< + prefill_policy_head64, + PipelineStages_Prefill, + 0, 1, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_bf16< + decode_paged_policy_head64, + PipelineStages_Decode, + 0, 1, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); diff --git a/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim64_kvcache_paged_fp16.cpp b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim64_kvcache_paged_fp16.cpp new file mode 100644 index 00000000..dd221976 --- /dev/null +++ b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim64_kvcache_paged_fp16.cpp @@ -0,0 +1,49 @@ +#include "fmha_fwd_impl.hpp" + +// Non-varlen + paged: IsVarLen=0, IsPaged=1, dtype=fp16 +// Used by mha_fwd_kvcache when block_table is provided. 
+ +// Prefill paged +template void policy_dispatch_fp16< + prefill_policy_head64, + PipelineStages_Prefill, + 0, 1>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +// Decode paged (smaller K-tile to fit page boundaries) +template void policy_dispatch_fp16< + decode_paged_policy_head64, + PipelineStages_Decode, + 0, 1>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +// Rotary kvcache variants keep rotary code out of non-rotary kernels. +template void policy_dispatch_fp16< + prefill_policy_head64, + PipelineStages_Prefill, + 0, 0, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_fp16< + decode_policy_head64, + PipelineStages_Decode, + 0, 0, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_fp16< + prefill_policy_head64, + PipelineStages_Prefill, + 0, 1, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_fp16< + decode_paged_policy_head64, + PipelineStages_Decode, + 0, 1, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); diff --git a/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim96_kvcache_paged_bf16.cpp b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim96_kvcache_paged_bf16.cpp new file mode 100644 index 00000000..0e270ada --- /dev/null +++ b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim96_kvcache_paged_bf16.cpp @@ -0,0 +1,49 @@ +#include "fmha_fwd_impl.hpp" + +// Non-varlen + paged: IsVarLen=0, IsPaged=1, dtype=bf16 +// Used by mha_fwd_kvcache when block_table is provided. + +// Prefill paged +template void policy_dispatch_bf16< + prefill_policy_head96, + PipelineStages_Prefill, + 0, 1>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +// Decode paged (smaller K-tile to fit page boundaries) +template void policy_dispatch_bf16< + decode_paged_policy_head96, + PipelineStages_Decode, + 0, 1>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +// Rotary kvcache variants keep rotary code out of non-rotary kernels. +template void policy_dispatch_bf16< + prefill_policy_head96, + PipelineStages_Prefill, + 0, 0, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_bf16< + decode_policy_head96, + PipelineStages_Decode, + 0, 0, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_bf16< + prefill_policy_head96, + PipelineStages_Prefill, + 0, 1, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); + +template void policy_dispatch_bf16< + decode_paged_policy_head96, + PipelineStages_Decode, + 0, 1, true>( + sycl::queue& queue, + const fmha_fwd_args_t& args); diff --git a/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim96_kvcache_paged_fp16.cpp b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim96_kvcache_paged_fp16.cpp new file mode 100644 index 00000000..5cd8032e --- /dev/null +++ b/flash-attn2/flash_attn_xpu/src/flash_fwd_hdim96_kvcache_paged_fp16.cpp @@ -0,0 +1,49 @@ +#include "fmha_fwd_impl.hpp" + +// Non-varlen + paged: IsVarLen=0, IsPaged=1, dtype=fp16 +// Used by mha_fwd_kvcache when block_table is provided. 
+
+// Prefill paged
+template void policy_dispatch_fp16<
+    prefill_policy_head96,
+    PipelineStages_Prefill,
+    0, 1>(
+    sycl::queue& queue,
+    const fmha_fwd_args_t& args);
+
+// Decode paged (smaller K-tile to fit page boundaries)
+template void policy_dispatch_fp16<
+    decode_paged_policy_head96,
+    PipelineStages_Decode,
+    0, 1>(
+    sycl::queue& queue,
+    const fmha_fwd_args_t& args);
+
+// Rotary kvcache variants keep rotary code out of non-rotary kernels.
+template void policy_dispatch_fp16<
+    prefill_policy_head96,
+    PipelineStages_Prefill,
+    0, 0, true>(
+    sycl::queue& queue,
+    const fmha_fwd_args_t& args);
+
+template void policy_dispatch_fp16<
+    decode_policy_head96,
+    PipelineStages_Decode,
+    0, 0, true>(
+    sycl::queue& queue,
+    const fmha_fwd_args_t& args);
+
+template void policy_dispatch_fp16<
+    prefill_policy_head96,
+    PipelineStages_Prefill,
+    0, 1, true>(
+    sycl::queue& queue,
+    const fmha_fwd_args_t& args);
+
+template void policy_dispatch_fp16<
+    decode_paged_policy_head96,
+    PipelineStages_Decode,
+    0, 1, true>(
+    sycl::queue& queue,
+    const fmha_fwd_args_t& args);
diff --git a/flash-attn2/flash_attn_xpu/src/fmha_bwd_impl.hpp b/flash-attn2/flash_attn_xpu/src/fmha_bwd_impl.hpp
index 640d43c0..ea7ac3ba 100644
--- a/flash-attn2/flash_attn_xpu/src/fmha_bwd_impl.hpp
+++ b/flash-attn2/flash_attn_xpu/src/fmha_bwd_impl.hpp
@@ -605,7 +605,7 @@ dq_dk_dv_1colblock(Trait &trait, BwdParam &param,
   // Math: dS = scale * P * (mask * rp * dP_dropped - dpsum)
   // where P is the original softmax output, dP_dropped = dO * V^T,
   // and dpsum = sum(dO * O) with O having rp scaling from forward.
-  constexpr int kDropMaskMax = 128;
+  constexpr int kDropMaskMax = decltype(size<0>(scores))::value * decltype(size<1>(scores))::value;
   bool drop_keep[kDropMaskMax];
   int drop_mask_count = 0;
diff --git a/flash-attn2/flash_attn_xpu/src/fmha_fwd.cpp b/flash-attn2/flash_attn_xpu/src/fmha_fwd.cpp
index 7e52e200..27d81d0e 100644
--- a/flash-attn2/flash_attn_xpu/src/fmha_fwd.cpp
+++ b/flash-attn2/flash_attn_xpu/src/fmha_fwd.cpp
@@ -55,6 +55,19 @@ void dispatch_varlen_decode_paged(sycl::queue& queue, CutlassType cuType,
   }
 }
 
+template <typename Policy, int Stages, int IsVarLen, int IsPaged>
+void dispatch_kvcache_policy(sycl::queue& queue, CutlassType cuType,
+                             const fmha_fwd_args_t& args,
+                             bool has_rotary) {
+  if (has_rotary) {
+    policy_dispatch<Policy, Stages, IsVarLen, IsPaged, true>(queue, cuType, args);
+  } else {
+    policy_dispatch<Policy, Stages, IsVarLen, IsPaged, false>(queue, cuType, args);
+  }
+}
+
 /// Dispatch forward kernel by head_size for the varlen prefill path.
 /// Supported head dimensions are bucketed to the corresponding prefill policies.
 void dispatch_fwd_varlen_by_head(sycl::queue& queue, CutlassType cuType,
@@ -112,6 +125,67 @@ void dispatch_fwd_prefill_by_head(sycl::queue& queue, CutlassType cuType,
   else throw std::runtime_error("Unsupported head_size: " + std::to_string(head_size) + ". Only <= 256 or exactly 512 is supported");
 }
 
+/// Dispatch forward kernel by head_size for the kvcache prefill paged path
+/// (non-varlen, seqlen_q > 1, IsPaged=1).
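+/// Example: head_size = 80 buckets to the head96 policy, i.e. it lowers to
+///   dispatch_kvcache_policy<prefill_policy_head96, PipelineStages_Prefill, 0, 1>(
+///       queue, cuType, args, has_rotary);
+/// which picks policy_dispatch_fp16/_bf16 from cuType and the HasRotary=true
+/// or =false instantiation from has_rotary.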
+void dispatch_fwd_kvcache_prefill_paged_by_head(sycl::queue& queue, CutlassType cuType,
+                                                const fmha_fwd_args_t& args, int head_size,
+                                                bool has_rotary) {
+  if (head_size <= 32) dispatch_kvcache_policy<prefill_policy_head32, PipelineStages_Prefill, 0, 1>(queue, cuType, args, has_rotary);
+  else if (head_size <= 64) dispatch_kvcache_policy<prefill_policy_head64, PipelineStages_Prefill, 0, 1>(queue, cuType, args, has_rotary);
+  else if (head_size <= 96) dispatch_kvcache_policy<prefill_policy_head96, PipelineStages_Prefill, 0, 1>(queue, cuType, args, has_rotary);
+  else if (head_size <= 128) dispatch_kvcache_policy<prefill_policy_head128, PipelineStages_Prefill, 0, 1>(queue, cuType, args, has_rotary);
+  else if (head_size <= 160) dispatch_kvcache_policy<prefill_policy_head160, PipelineStages_Prefill, 0, 1>(queue, cuType, args, has_rotary);
+  else if (head_size <= 192) dispatch_kvcache_policy<prefill_policy_head192, PipelineStages_Prefill, 0, 1>(queue, cuType, args, has_rotary);
+  else if (head_size <= 256) dispatch_kvcache_policy<prefill_policy_head256, PipelineStages_Prefill, 0, 1>(queue, cuType, args, has_rotary);
+  else if (head_size == 512) dispatch_kvcache_policy<prefill_policy_head512, PipelineStages_Prefill, 0, 1>(queue, cuType, args, has_rotary);
+  else throw std::runtime_error("Unsupported head_size: " + std::to_string(head_size) + ". Only <= 256 or exactly 512 is supported");
+}
+
+/// Dispatch forward kernel by head_size for the kvcache decode paged path
+/// (non-varlen, seqlen_q == 1, IsPaged=1). Uses decode_paged_policy with a
+/// smaller K-tile so block_size=64 multiples don't overshoot a page.
+void dispatch_fwd_kvcache_decode_paged_by_head(sycl::queue& queue, CutlassType cuType,
+                                               const fmha_fwd_args_t& args, int head_size,
+                                               bool has_rotary) {
+  if (head_size <= 32) dispatch_kvcache_policy<decode_paged_policy_head32, PipelineStages_Decode, 0, 1>(queue, cuType, args, has_rotary);
+  else if (head_size <= 64) dispatch_kvcache_policy<decode_paged_policy_head64, PipelineStages_Decode, 0, 1>(queue, cuType, args, has_rotary);
+  else if (head_size <= 96) dispatch_kvcache_policy<decode_paged_policy_head96, PipelineStages_Decode, 0, 1>(queue, cuType, args, has_rotary);
+  else if (head_size <= 128) dispatch_kvcache_policy<decode_paged_policy_head128, PipelineStages_Decode, 0, 1>(queue, cuType, args, has_rotary);
+  else if (head_size <= 160) dispatch_kvcache_policy<decode_paged_policy_head160, PipelineStages_Decode, 0, 1>(queue, cuType, args, has_rotary);
+  else if (head_size <= 192) dispatch_kvcache_policy<decode_paged_policy_head192, PipelineStages_Decode, 0, 1>(queue, cuType, args, has_rotary);
+  else if (head_size <= 256) dispatch_kvcache_policy<decode_paged_policy_head256, PipelineStages_Decode, 0, 1>(queue, cuType, args, has_rotary);
+  else if (head_size == 512) dispatch_kvcache_policy<decode_paged_policy_head512, PipelineStages_Decode, 0, 1>(queue, cuType, args, has_rotary);
+  else throw std::runtime_error("Unsupported head_size: " + std::to_string(head_size) + ". Only <= 256 or exactly 512 is supported");
+}
+
+void dispatch_fwd_kvcache_decode_by_head(sycl::queue& queue, CutlassType cuType,
+                                         const fmha_fwd_args_t& args, int head_size,
+                                         bool has_rotary) {
+  if (head_size <= 32) dispatch_kvcache_policy<decode_policy_head32, PipelineStages_Decode, 0, 0>(queue, cuType, args, has_rotary);
+  else if (head_size <= 64) dispatch_kvcache_policy<decode_policy_head64, PipelineStages_Decode, 0, 0>(queue, cuType, args, has_rotary);
+  else if (head_size <= 96) dispatch_kvcache_policy<decode_policy_head96, PipelineStages_Decode, 0, 0>(queue, cuType, args, has_rotary);
+  else if (head_size <= 128) dispatch_kvcache_policy<decode_policy_head128, PipelineStages_Decode, 0, 0>(queue, cuType, args, has_rotary);
+  else if (head_size <= 160) dispatch_kvcache_policy<decode_policy_head160, PipelineStages_Decode, 0, 0>(queue, cuType, args, has_rotary);
+  else if (head_size <= 192) dispatch_kvcache_policy<decode_policy_head192, PipelineStages_Decode, 0, 0>(queue, cuType, args, has_rotary);
+  else if (head_size <= 256) dispatch_kvcache_policy<decode_policy_head256, PipelineStages_Decode, 0, 0>(queue, cuType, args, has_rotary);
+  else if (head_size == 512) dispatch_kvcache_policy<decode_policy_head512, PipelineStages_Decode, 0, 0>(queue, cuType, args, has_rotary);
+  else throw std::runtime_error("Unsupported head_size: " + std::to_string(head_size) + ". 
Only <= 256 or exactly 512 is supported"); +} + +void dispatch_fwd_kvcache_prefill_by_head(sycl::queue& queue, CutlassType cuType, + const fmha_fwd_args_t& args, int head_size, + bool has_rotary) { + if (head_size <= 32) dispatch_kvcache_policy(queue, cuType, args, has_rotary); + else if (head_size <= 64) dispatch_kvcache_policy(queue, cuType, args, has_rotary); + else if (head_size <= 96) dispatch_kvcache_policy(queue, cuType, args, has_rotary); + else if (head_size <= 128) dispatch_kvcache_policy(queue, cuType, args, has_rotary); + else if (head_size <= 160) dispatch_kvcache_policy(queue, cuType, args, has_rotary); + else if (head_size <= 192) dispatch_kvcache_policy(queue, cuType, args, has_rotary); + else if (head_size <= 256) dispatch_kvcache_policy(queue, cuType, args, has_rotary); + else if (head_size == 512) dispatch_kvcache_policy(queue, cuType, args, has_rotary); + else throw std::runtime_error("Unsupported head_size: " + std::to_string(head_size) + ". Only <= 256 or exactly 512 is supported"); +} + /// Clamp window sizes for local attention and fold causal into local when both are set. void normalize_window_params(int& window_size_left, int& window_size_right, bool& is_causal, bool is_local, int max_seqlen_k) { @@ -306,3 +380,142 @@ void cutlass_fmha_fwd_fix_impl( dispatch_fwd_prefill_by_head(queue, cuType, args, h); } } + +void cutlass_fmha_fwd_kvcache_impl( + sycl::queue& queue, + const at::Tensor& query, + const at::Tensor& kcache, + const at::Tensor& vcache, + at::Tensor& out, + at::Tensor& softmax_lse, + const at::Tensor& seqlens_k, + const std::optional& cache_batch_idx, + const std::optional& cache_leftpad, + const std::optional& knew, + const std::optional& vnew, + const std::optional& block_table, + const std::optional& rotary_cos, + const std::optional& rotary_sin, + int rotary_dim, + bool is_rotary_interleaved, + int max_seqlen_k_paged, + float sm_scale, + int window_size_left, + int window_size_right, + bool is_causal, + bool is_local) { + const bool is_paged = block_table.has_value() && block_table->defined(); + + const int batch_size = query.size(0); + const int max_seqlen_q = query.size(1); + const int num_heads_q = query.size(2); + const int head_size = query.size(3); + + int max_seqlen_k; + int num_heads_kv; + int num_blocks = 0; + int block_size = 0; + int max_blocks_per_seq = 0; + if (is_paged) { + // Paged layout: kcache is (num_blocks, page_block_size, num_heads_kv, head_size) + num_blocks = kcache.size(0); + block_size = kcache.size(1); + num_heads_kv = kcache.size(2); + max_blocks_per_seq = block_table->size(1); + max_seqlen_k = max_seqlen_k_paged; + TORCH_CHECK(num_blocks * block_size >= max_seqlen_k, + "Paged KV pool too small for max_seqlen_k: ", + num_blocks * block_size, " < ", max_seqlen_k); + TORCH_CHECK(!cache_batch_idx.has_value(), + "Paged KVcache does not support cache_batch_idx"); + } else { + max_seqlen_k = kcache.size(1); + num_heads_kv = kcache.size(2); + } + + const int total_seqlen_q = batch_size * max_seqlen_q; + const int total_seqlen_k = is_paged + ? num_blocks * block_size + : batch_size * max_seqlen_k; + + normalize_window_params(window_size_left, window_size_right, + is_causal, is_local, max_seqlen_k); + + fmha_fwd_args_t args = { + query.data_ptr(), + kcache.data_ptr(), + vcache.data_ptr(), + out.data_ptr(), + softmax_lse.data_ptr(), + is_paged ? 
block_table->data_ptr() : nullptr, + nullptr, // cu_seqlens_q + nullptr, // cu_seqlens_k + max_seqlen_q, + max_seqlen_k, + total_seqlen_q, + total_seqlen_k, + sm_scale, + batch_size, + num_heads_q, + num_heads_kv, + head_size, + max_blocks_per_seq, + block_size, + window_size_left, + window_size_right, + false, // is_varlen + is_paged, + is_causal, + is_local, + 0.0f, 0, 0, nullptr, nullptr, 0, 0, // dropout & s_dmask defaults + static_cast(seqlens_k.data_ptr()), + cache_batch_idx.has_value() + ? static_cast(cache_batch_idx->data_ptr()) + : nullptr, + cache_leftpad.has_value() + ? static_cast(cache_leftpad->data_ptr()) + : nullptr, + knew.has_value() ? knew->data_ptr() : nullptr, + vnew.has_value() ? vnew->data_ptr() : nullptr, + knew.has_value() ? static_cast(knew->size(1)) : 0, + knew.has_value() ? knew->stride(0) : 0, + knew.has_value() ? knew->stride(2) : 0, + knew.has_value() ? knew->stride(1) : 0, + vnew.has_value() ? vnew->stride(0) : 0, + vnew.has_value() ? vnew->stride(2) : 0, + vnew.has_value() ? vnew->stride(1) : 0, + rotary_cos.has_value() ? rotary_cos->data_ptr() : nullptr, + rotary_sin.has_value() ? rotary_sin->data_ptr() : nullptr, + rotary_dim, + is_rotary_interleaved}; + + const CutlassType cuType = aten_to_Cutlass_dtype(query); + const int h = args.head_size; + const bool has_rotary = args.rotary_dim > 0 && args.rotary_cos != nullptr && + args.rotary_sin != nullptr; + + if (is_paged) { + // Paged dispatch requires the K-tile to evenly divide block_size, + // because each tile is loaded from a single page. + const int k_tile_n = paged_k_tile_size_n(h, max_seqlen_q); + TORCH_CHECK(k_tile_n > 0, + "Unsupported head_size for paged FA2 kvcache: ", h); + TORCH_CHECK(is_supported_paged_block_size(block_size), + "Unsupported paged KV block_size=", block_size, + ". Supported values are positive multiples of 64."); + TORCH_CHECK(block_size % k_tile_n == 0, + "Paged KV block_size must be a multiple of the kernel K tile. " + "Got block_size=", block_size, ", K tile=", k_tile_n); + if (max_seqlen_q == 1) { + dispatch_fwd_kvcache_decode_paged_by_head(queue, cuType, args, h, has_rotary); + } else { + dispatch_fwd_kvcache_prefill_paged_by_head(queue, cuType, args, h, has_rotary); + } + } else { + if (max_seqlen_q == 1) { + dispatch_fwd_kvcache_decode_by_head(queue, cuType, args, h, has_rotary); + } else { + dispatch_fwd_kvcache_prefill_by_head(queue, cuType, args, h, has_rotary); + } + } +} diff --git a/flash-attn2/flash_attn_xpu/src/fmha_fwd.hpp b/flash-attn2/flash_attn_xpu/src/fmha_fwd.hpp index baa3758d..ed1576f1 100644 --- a/flash-attn2/flash_attn_xpu/src/fmha_fwd.hpp +++ b/flash-attn2/flash_attn_xpu/src/fmha_fwd.hpp @@ -44,27 +44,30 @@ struct decode_paged_policy_head256; struct decode_paged_policy_head512; // Dtype-specific dispatch functions (instantiated in per-head TUs) -template +template void policy_dispatch_fp16( sycl::queue& queue, const fmha_fwd_args_t& args); -template +template void policy_dispatch_bf16( sycl::queue& queue, const fmha_fwd_args_t& args); // Combined dispatch (delegates to fp16/bf16 based on cuType) // Defined inline in header so callers (fmha_fwd.cpp) can see the template body. 
-template +template inline void policy_dispatch( sycl::queue& queue, CutlassType cuType, const fmha_fwd_args_t& args) { if (cuType == CutlassType::half) { - policy_dispatch_fp16(queue, args); + policy_dispatch_fp16(queue, args); } else { - policy_dispatch_bf16(queue, args); + policy_dispatch_bf16(queue, args); } } @@ -119,6 +122,23 @@ EXTERN_DISPATCH_FIX(256) EXTERN_DISPATCH_FIX(512) #undef EXTERN_DISPATCH_FIX +// KVCache + paged extern declarations (IsVarLen=0, IsPaged=1) +#define EXTERN_DISPATCH_KVCACHE_PAGED(HDIM) \ + extern template void policy_dispatch_fp16(sycl::queue&, const fmha_fwd_args_t&); \ + extern template void policy_dispatch_fp16(sycl::queue&, const fmha_fwd_args_t&); \ + extern template void policy_dispatch_bf16(sycl::queue&, const fmha_fwd_args_t&); \ + extern template void policy_dispatch_bf16(sycl::queue&, const fmha_fwd_args_t&); + +EXTERN_DISPATCH_KVCACHE_PAGED(32) +EXTERN_DISPATCH_KVCACHE_PAGED(64) +EXTERN_DISPATCH_KVCACHE_PAGED(96) +EXTERN_DISPATCH_KVCACHE_PAGED(128) +EXTERN_DISPATCH_KVCACHE_PAGED(160) +EXTERN_DISPATCH_KVCACHE_PAGED(192) +EXTERN_DISPATCH_KVCACHE_PAGED(256) +EXTERN_DISPATCH_KVCACHE_PAGED(512) +#undef EXTERN_DISPATCH_KVCACHE_PAGED + void cutlass_fmha_fwd_varlen_impl( sycl::queue& queue, const at::Tensor& query, @@ -158,3 +178,27 @@ void cutlass_fmha_fwd_fix_impl( void* s_dmask = nullptr, int seqlen_q_rounded = 0, int seqlen_k_rounded = 0); + +void cutlass_fmha_fwd_kvcache_impl( + sycl::queue& queue, + const at::Tensor& query, + const at::Tensor& kcache, + const at::Tensor& vcache, + at::Tensor& out, + at::Tensor& softmax_lse, + const at::Tensor& seqlens_k, + const std::optional& cache_batch_idx, + const std::optional& cache_leftpad, + const std::optional& knew, + const std::optional& vnew, + const std::optional& block_table, + const std::optional& rotary_cos, + const std::optional& rotary_sin, + int rotary_dim, + bool is_rotary_interleaved, + int max_seqlen_k_paged, + float sm_scale, + int window_size_left, + int window_size_right, + bool is_causal, + bool is_local); diff --git a/flash-attn2/flash_attn_xpu/src/fmha_fwd_impl.hpp b/flash-attn2/flash_attn_xpu/src/fmha_fwd_impl.hpp index 5e6d00e4..e686170c 100644 --- a/flash-attn2/flash_attn_xpu/src/fmha_fwd_impl.hpp +++ b/flash-attn2/flash_attn_xpu/src/fmha_fwd_impl.hpp @@ -1,17 +1,3 @@ -/*************************************************************************************************** - * Copyright (C) 2025 Intel Corporation, All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Xe2 (BMG / Arc Pro B60) FMHA forward dispatch. Builds the full - * mainloop+epilogue+kernel for a given tile policy and feature combination, - * and launches it. - * - * The PVC path has been removed; only the Xe2 fork is built. Each per-head - * translation unit instantiates the (Causal x Local x Dropout) cases for a - * single (IsVarLen, IsPaged) combination, controlled by the IsVarLen / IsPaged - * template arguments to policy_dispatch. 
- **************************************************************************************************/ - #pragma once #include "fmha_fwd_types.hpp" @@ -72,7 +58,8 @@ struct FMHAConfig { bool Local, bool Dropout, bool Paged, - bool VarLen> + bool VarLen, + bool HasRotary> static void run_xe2(sycl::queue& queue, const fmha_fwd_args_t& args) { cutlass::KernelHardwareInfo hw_info; @@ -118,7 +105,8 @@ struct FMHAConfig { TensorV, GmemTiledCopyQ, GmemTiledCopyK, - GmemTiledCopyV>; + GmemTiledCopyV, + HasRotary>; using CollectiveEpilogue = cutlass::fmha::collective::FMHAFwdEpilogueXe2< CollectiveMainloop, @@ -195,7 +183,15 @@ struct FMHAConfig { o_batch_stride, o_head_stride, o_row_stride, reinterpret_cast(args.softmax_lse), lse_stride_head, - lse_stride_batch}, + lse_stride_batch, + args.cache_seqlens, + args.cache_batch_idx, + args.cache_leftpad, + reinterpret_cast(args.knew), + args.knew_batch_stride, args.knew_head_stride, args.knew_row_stride, + reinterpret_cast(args.vnew), + args.vnew_batch_stride, args.vnew_head_stride, args.vnew_row_stride, + args.seqlen_knew}, {static_cast(args.sm_scale), args.window_size_left, args.window_size_right, args.p_dropout, @@ -205,7 +201,11 @@ struct FMHAConfig { static_cast(args.block_table), args.block_size, args.max_blocks_per_seq, - args.total_seqlen_k}, + args.total_seqlen_k, + reinterpret_cast(args.rotary_cos), + reinterpret_cast(args.rotary_sin), + args.rotary_dim, + args.is_rotary_interleaved}, {}, hw_info}; @@ -241,6 +241,7 @@ struct FMHAConfig { template < int IsVarLen, int IsPaged, + bool HasRotary = false, class Scheduler = cutlass::fmha::kernel::XeFHMAIndividualTileScheduler> static void xe2_dispatch(sycl::queue& queue, const fmha_fwd_args_t& args) { @@ -254,7 +255,7 @@ struct FMHAConfig { bool has_dropout = args.p_dropout > 0.0f; #define XE2_CASE(C, L, D) \ if (args.is_causal == C && args.is_local == L && has_dropout == D) { \ - run_xe2(queue, args); \ + run_xe2(queue, args); \ return; \ } XE2_CASE(false, false, false) @@ -271,7 +272,8 @@ struct FMHAConfig { // Single-dtype dispatch: only instantiates one dtype path per TU to reduce // per-file IGC memory usage from ~40 GB to ~20 GB. 
-template +template void policy_dispatch_fp16(sycl::queue& queue, const fmha_fwd_args_t& args) { using Config = FMHAConfig< typename chunk_policy::ShapeQK, @@ -281,10 +283,11 @@ void policy_dispatch_fp16(sycl::queue& queue, const fmha_fwd_args_t& args) { void, PipelineStages, half_t, half_t, half_t, half_t>; - Config::template xe2_dispatch(queue, args); + Config::template xe2_dispatch(queue, args); } -template +template void policy_dispatch_bf16(sycl::queue& queue, const fmha_fwd_args_t& args) { using Config = FMHAConfig< typename chunk_policy::ShapeQK, @@ -294,7 +297,7 @@ void policy_dispatch_bf16(sycl::queue& queue, const fmha_fwd_args_t& args) { void, PipelineStages, bfloat16_t, bfloat16_t, bfloat16_t, bfloat16_t>; - Config::template xe2_dispatch(queue, args); + Config::template xe2_dispatch(queue, args); } // Combined policy_dispatch is now defined inline in fmha_fwd.hpp diff --git a/flash-attn2/flash_attn_xpu/src/fmha_fwd_types.hpp b/flash-attn2/flash_attn_xpu/src/fmha_fwd_types.hpp index 6d004e1b..7730faca 100644 --- a/flash-attn2/flash_attn_xpu/src/fmha_fwd_types.hpp +++ b/flash-attn2/flash_attn_xpu/src/fmha_fwd_types.hpp @@ -39,6 +39,28 @@ struct fmha_fwd_args_t { void* s_dmask = nullptr; // Output: attention matrix with dropout sign-bit encoding int seqlen_q_rounded = 0; // Q sequence length rounded up (stride for s_dmask) int seqlen_k_rounded = 0; // K sequence length rounded up (stride for s_dmask) + + // KV Cache fields (optional, set to nullptr for non-kvcache paths) + int* cache_seqlens = nullptr; // (batch_size,) per-batch effective KV length + int* cache_batch_idx = nullptr; // (batch_size,) indices into KV cache batch dim + int* cache_leftpad = nullptr; // (batch_size,) left padding per batch in KV cache + + // Fused KV cache append: new K/V to scatter into cache inside the kernel + void* knew = nullptr; + void* vnew = nullptr; + int seqlen_knew = 0; + int64_t knew_batch_stride = 0; + int64_t knew_head_stride = 0; + int64_t knew_row_stride = 0; + int64_t vnew_batch_stride = 0; + int64_t vnew_head_stride = 0; + int64_t vnew_row_stride = 0; + + // Fused rotary embedding for kvcache append and Q load + void* rotary_cos = nullptr; + void* rotary_sin = nullptr; + int rotary_dim = 0; + bool is_rotary_interleaved = true; }; enum class CutlassType { diff --git a/flash-attn2/flash_attn_xpu/src/kernel/fmha_fwd_kernel_xe2.hpp b/flash-attn2/flash_attn_xpu/src/kernel/fmha_fwd_kernel_xe2.hpp index d594c7de..5f78f6b9 100644 --- a/flash-attn2/flash_attn_xpu/src/kernel/fmha_fwd_kernel_xe2.hpp +++ b/flash-attn2/flash_attn_xpu/src/kernel/fmha_fwd_kernel_xe2.hpp @@ -98,6 +98,20 @@ class XeFMHAFwdKernelXe2 { float* pLSE; int lse_stride_head; int lse_stride_batch; + // KV Cache: per-batch effective KV length (nullptr for non-kvcache paths) + int* cache_seqlens = nullptr; + int* cache_batch_idx = nullptr; + int* cache_leftpad = nullptr; + // Fused KV cache append + const ElementK* Knew = nullptr; + int64_t knew_batch_stride = 0; + int64_t knew_head_stride = 0; + int64_t knew_row_stride = 0; + const ElementV* Vnew = nullptr; + int64_t vnew_batch_stride = 0; + int64_t vnew_head_stride = 0; + int64_t vnew_row_stride = 0; + int seqlen_knew = 0; }; using KernelParams = KernelArguments; @@ -211,11 +225,132 @@ class XeFMHAFwdKernelXe2 { if (blk_q * get<0>(TileShapeQK{}) >= seq_len_qo) continue; - auto full_tile_offset = seq_len_kv - seq_len_qo; + // KV Cache: override seq_len_kv with per-batch effective length + int effective_seq_kv = seq_len_kv; + int leftpad_k = 0; + // bidx maps the logical request 
to a slot in the physical KV cache. + // - paged path: cache_batch_idx is forbidden, K/V are flat blocks + // - non-paged path: bidx selects the per-batch KV slice + int bidx = (!PagedKV && p.cache_batch_idx) ? p.cache_batch_idx[idx_b] : idx_b; + if (p.cache_seqlens) { + int orig_cache_seqlens = p.cache_seqlens[idx_b]; + if (p.cache_leftpad) { + leftpad_k = p.cache_leftpad[idx_b]; + } + + // Fused cache update: copy knew/vnew into kcache/vcache + if (p.Knew != nullptr && p.seqlen_knew > 0) { + constexpr int num_threads = SGPerWG::value * cute::intel::sg_size; + auto* k_src = p.Knew + + idx_b * p.knew_batch_stride + head * p.knew_head_stride; + auto* v_src = p.Vnew + + idx_b * p.vnew_batch_stride + head * p.vnew_head_stride; + if constexpr (PagedKV) { + // Paged scatter: per-token compute (block, page_offset) from + // block_table, then write to the corresponding page slot. + // Each "block" in the paged K/V tensor spans `page_size` rows, + // so the per-block byte stride is `page_size * row_stride`, + // not `k_batch_stride` (which is sized for the *whole* logical + // KV layout, not per-page). + const int page_size = params.mainloop.paged.page_size; + const int max_pages_per_seq = + params.mainloop.paged.max_pages_per_seq; + const int* page_table = params.mainloop.paged.ptr_page_table + + idx_b * max_pages_per_seq; + const int64_t k_block_stride = + static_cast(page_size) * p.k_row_stride; + const int64_t v_block_stride = + static_cast(page_size) * p.v_row_stride; + for (int si = 0; si < p.seqlen_knew; si++) { + int global_pos = orig_cache_seqlens + si; + int page_idx = global_pos / page_size; + int page_off = global_pos % page_size; + int block = page_table[page_idx]; + auto* k_dst = const_cast(p.K) + + static_cast(block) * k_block_stride + + head * p.k_head_stride + + static_cast(page_off) * p.k_row_stride; + auto* v_dst = const_cast(p.V) + + static_cast(block) * v_block_stride + + head * p.v_head_stride + + static_cast(page_off) * p.v_row_stride; + for (int d = thr_id; d < s.head_size_qk; d += num_threads) { + auto k_value = k_src[si * p.knew_row_stride + d]; + if constexpr (CollectiveMainloop::HasRotary) { + if (params.mainloop.rotary.rotary_dim > 0 && + params.mainloop.rotary.rotary_cos != nullptr && + params.mainloop.rotary.rotary_sin != nullptr && + d < params.mainloop.rotary.rotary_dim) { + int pair_dim = cutlass::fmha::collective::rotary_pair_dim( + d, params.mainloop.rotary.rotary_dim, + params.mainloop.rotary.is_rotary_interleaved); + if (pair_dim < s.head_size_qk) { + k_value = cutlass::fmha::collective::apply_rotary_scalar( + k_value, k_src[si * p.knew_row_stride + pair_dim], + params.mainloop.rotary.rotary_cos, + params.mainloop.rotary.rotary_sin, + global_pos, d, params.mainloop.rotary.rotary_dim, + params.mainloop.rotary.is_rotary_interleaved); + } + } + } + k_dst[d] = k_value; + } + for (int d = thr_id; d < s.head_size_vo; d += num_threads) { + v_dst[d] = v_src[si * p.vnew_row_stride + d]; + } + } + } else { + auto* k_dst = const_cast(p.K) + + bidx * p.k_batch_stride + head * p.k_head_stride + + static_cast(orig_cache_seqlens) * p.k_row_stride; + auto* v_dst = const_cast(p.V) + + bidx * p.v_batch_stride + head * p.v_head_stride + + static_cast(orig_cache_seqlens) * p.v_row_stride; + for (int si = 0; si < p.seqlen_knew; si++) { + for (int d = thr_id; d < s.head_size_qk; d += num_threads) { + auto k_value = k_src[si * p.knew_row_stride + d]; + if constexpr (CollectiveMainloop::HasRotary) { + if (params.mainloop.rotary.rotary_dim > 0 && + params.mainloop.rotary.rotary_cos != 
nullptr && + params.mainloop.rotary.rotary_sin != nullptr && + d < params.mainloop.rotary.rotary_dim) { + int pair_dim = cutlass::fmha::collective::rotary_pair_dim( + d, params.mainloop.rotary.rotary_dim, + params.mainloop.rotary.is_rotary_interleaved); + if (pair_dim < s.head_size_qk) { + k_value = cutlass::fmha::collective::apply_rotary_scalar( + k_value, k_src[si * p.knew_row_stride + pair_dim], + params.mainloop.rotary.rotary_cos, + params.mainloop.rotary.rotary_sin, + orig_cache_seqlens + si, d, + params.mainloop.rotary.rotary_dim, + params.mainloop.rotary.is_rotary_interleaved); + } + } + } + k_dst[si * p.k_row_stride + d] = k_value; + } + } + for (int si = 0; si < p.seqlen_knew; si++) { + for (int d = thr_id; d < s.head_size_vo; d += num_threads) { + v_dst[si * p.v_row_stride + d] = v_src[si * p.vnew_row_stride + d]; + } + } + } + sycl::group_barrier(get_work_group<3>()); + effective_seq_kv = (orig_cache_seqlens + p.seqlen_knew) - leftpad_k; + } else { + effective_seq_kv = orig_cache_seqlens - leftpad_k; + } + } + if (effective_seq_kv <= 0) continue; + + auto full_tile_offset = effective_seq_kv - seq_len_qo; int seq_coord = cute::min( seq_len_qo, (blk_q * get<0>(TileShapeQK{}) + q_offset_sg)); int last_seq_coord = seq_coord + q_sg_tile - 1; - int first_non_masked_sequence = seq_len_qo - seq_len_kv; + int first_non_masked_sequence = seq_len_qo - effective_seq_kv; // Causal-only early-exit: skip SGs that are fully masked. With // LocalMask we can't easily do this here, so let the loop body mask. @@ -228,15 +363,15 @@ class XeFMHAFwdKernelXe2 { int seq_len; if constexpr (CausalMask && LocalMask) { seq_len = cute::min( - seq_len_kv, + effective_seq_kv, full_tile_offset + seq_coord + q_sg_tile + params.mainloop.local.local_right); } else if constexpr (CausalMask) { seq_len = calculate_longest_non_masked_length( - seq_len_kv, seq_len_qo, last_seq_coord, + effective_seq_kv, seq_len_qo, last_seq_coord, first_non_masked_sequence); } else { - seq_len = seq_len_kv; + seq_len = effective_seq_kv; } if (seq_len < 0) seq_len = 0; @@ -265,21 +400,29 @@ class XeFMHAFwdKernelXe2 { offset_o = s.num_heads_q * s.head_size_vo * qo_cumulative[idx_b]; } - auto batch_dim = is_var_len ? 1 : s.batch; + auto batch_dim_q = is_var_len ? 1 : s.batch; + // Paged KV is laid out as (num_blocks * page_size, head, num_heads_kv) + // with no batch dimension; treat it like the varlen K/V layout. + auto batch_dim_kv = (is_var_len || PagedKV) ? 1 : s.batch; int total_seqlen_kv; if constexpr (PagedKV) { total_seqlen_kv = params.mainloop.paged.total_seqlen_kv; } else { total_seqlen_kv = seq_len_kv; } + // When leftpad is applied, the K/V base pointer is shifted forward. + // Reduce the surface height so it accurately describes the data + // reachable from the shifted base, preventing 2D block loads from + // extending past the per-batch allocation. 
+ int kv_surface_len = total_seqlen_kv - leftpad_k; auto shape_Q = - make_shape(seq_len_qo, s.head_size_qk, s.num_heads_q, batch_dim); + make_shape(seq_len_qo, s.head_size_qk, s.num_heads_q, batch_dim_q); auto shape_K = make_shape( - total_seqlen_kv, s.head_size_qk, s.num_heads_kv, batch_dim); + kv_surface_len, s.head_size_qk, s.num_heads_kv, batch_dim_kv); auto shape_V = make_shape( - s.head_size_vo, total_seqlen_kv, s.num_heads_kv, batch_dim); + s.head_size_vo, kv_surface_len, s.num_heads_kv, batch_dim_kv); auto shape_O = - make_shape(seq_len_qo, s.head_size_vo, s.num_heads_q, batch_dim); + make_shape(seq_len_qo, s.head_size_vo, s.num_heads_q, batch_dim_q); auto stride_q = cutlass::make_stride( static_cast(p.q_row_stride), Int<1>{}, @@ -303,6 +446,12 @@ class XeFMHAFwdKernelXe2 { auto dcV = const_cast(p.V + offset_v); auto ptrO = p.O + offset_o; + // Offset K/V by leftpad to skip left-padding tokens in the cache + if (leftpad_k > 0) { + dcK += leftpad_k * static_cast(p.k_row_stride); + dcV += leftpad_k * static_cast(p.v_row_stride); + } + auto layout_q = is_var_len ? make_ordered_layout(shape_Q, Step<_2, _0, _1, _3>{}) : make_layout(shape_Q, stride_q); @@ -327,12 +476,16 @@ class XeFMHAFwdKernelXe2 { int rows_of_maxima = get<0>(shape_div(TileShapeQK{}, shape(SubgroupLayoutQK{}))); - int l_coord = is_var_len ? 0 : idx_b; + // For non-paged KV, reuse cache_batch_idx remap (bidx) so that the + // KV slice matches the per-request seqlen. Q/O always use idx_b. + int l_coord_q = is_var_len ? 0 : idx_b; + int l_coord_kv = is_var_len ? 0 + : (PagedKV ? 0 : bidx); CollectiveMainloop mainloop(params.mainloop, shared_storage.mainloop); mainloop( - Q(_, _, head_q, l_coord), - K(_, _, head, l_coord), - V(_, _, head, l_coord), + Q(_, _, head_q, l_coord_q), + K(_, _, head, l_coord_kv), + V(_, _, head, l_coord_kv), tArA, tA_max, tA_sum, @@ -343,12 +496,14 @@ class XeFMHAFwdKernelXe2 { thr_id, seq_len, seq_len_qo, - seq_len_kv, + effective_seq_kv, idx_b, tile_row_idx, rows_of_maxima, head_q, - s.num_heads_q); + s.num_heads_q, + q_offset_sg, + p.cache_seqlens ? p.cache_seqlens[idx_b] : 0); if constexpr ( !is_empty_v && @@ -367,7 +522,7 @@ class XeFMHAFwdKernelXe2 { tile_row_idx, rows_of_maxima); epilogue( - O(_, _, head_q, l_coord), + O(_, _, head_q, l_coord_q), tArA, tA_max, tA_sum, diff --git a/flash-attn2/tests/test_flash_attn.py b/flash-attn2/tests/test_flash_attn.py index 5b138722..6fc65305 100644 --- a/flash-attn2/tests/test_flash_attn.py +++ b/flash-attn2/tests/test_flash_attn.py @@ -2026,7 +2026,8 @@ def test_flash_attn_kvcache( if device == "cpu": pytest.skip("kvcache not supported on CPU") if device == "xpu": - pytest.skip("kvcache not supported on xpu currently") + if alibi: + pytest.skip("alibi not supported on xpu currently") if seqlen_q > seqlen_k and new_kv: pytest.skip() if not new_kv and rotary_fraction > 0.0: diff --git a/flash-attn2/torch-ext/torch_binding.cpp b/flash-attn2/torch-ext/torch_binding.cpp index 37403521..b277ea46 100644 --- a/flash-attn2/torch-ext/torch_binding.cpp +++ b/flash-attn2/torch-ext/torch_binding.cpp @@ -144,6 +144,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "int num_splits) -> Tensor[]"); #if defined(CUDA_KERNEL) ops.impl("fwd_kvcache", torch::kCUDA, &mha_fwd_kvcache); +#elif defined(XPU_KERNEL) + ops.impl("fwd_kvcache", torch::kXPU, &mha_fwd_kvcache); #endif }