16 changes: 16 additions & 0 deletions flash-attn2/build.toml
@@ -210,6 +210,22 @@ src = [
"flash_attn_xpu/src/flash_fwd_hdim256_fix_bf16.cpp",
"flash_attn_xpu/src/flash_fwd_hdim512_fix_fp16.cpp",
"flash_attn_xpu/src/flash_fwd_hdim512_fix_bf16.cpp",
"flash_attn_xpu/src/flash_fwd_hdim32_kvcache_paged_fp16.cpp",
"flash_attn_xpu/src/flash_fwd_hdim32_kvcache_paged_bf16.cpp",
"flash_attn_xpu/src/flash_fwd_hdim64_kvcache_paged_fp16.cpp",
"flash_attn_xpu/src/flash_fwd_hdim64_kvcache_paged_bf16.cpp",
"flash_attn_xpu/src/flash_fwd_hdim96_kvcache_paged_fp16.cpp",
"flash_attn_xpu/src/flash_fwd_hdim96_kvcache_paged_bf16.cpp",
"flash_attn_xpu/src/flash_fwd_hdim128_kvcache_paged_fp16.cpp",
"flash_attn_xpu/src/flash_fwd_hdim128_kvcache_paged_bf16.cpp",
"flash_attn_xpu/src/flash_fwd_hdim160_kvcache_paged_fp16.cpp",
"flash_attn_xpu/src/flash_fwd_hdim160_kvcache_paged_bf16.cpp",
"flash_attn_xpu/src/flash_fwd_hdim192_kvcache_paged_fp16.cpp",
"flash_attn_xpu/src/flash_fwd_hdim192_kvcache_paged_bf16.cpp",
"flash_attn_xpu/src/flash_fwd_hdim256_kvcache_paged_fp16.cpp",
"flash_attn_xpu/src/flash_fwd_hdim256_kvcache_paged_bf16.cpp",
"flash_attn_xpu/src/flash_fwd_hdim512_kvcache_paged_fp16.cpp",
"flash_attn_xpu/src/flash_fwd_hdim512_kvcache_paged_bf16.cpp",
"flash_attn_xpu/src/fmha_bwd_types.hpp",
"flash_attn_xpu/src/fmha_bwd.hpp",
"flash_attn_xpu/src/fmha_bwd_impl.hpp",
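Each new source above follows the pattern of the existing per-head-dim files: one translation unit per (head-dim, dtype) combination, which keeps compile times and compiler memory bounded and lets the build parallelize. The files themselves are not shown in this diff; a hypothetical sketch of what one typically contains (the entry-point and argument-struct names below are assumptions, not the repo's actual symbols):

// flash_fwd_hdim128_kvcache_paged_fp16.cpp -- hypothetical sketch only.
#include <sycl/sycl.hpp>

#include "fmha_fwd.hpp"

namespace FLASH_NAMESPACE {

// Explicit instantiation pins this (head-dim, dtype) variant to its own TU;
// `run_fmha_fwd_kvcache_paged` and `fmha_fwd_args` are assumed names.
template void run_fmha_fwd_kvcache_paged<cutlass::half_t, /*HeadDim=*/128>(
    sycl::queue& queue, fmha_fwd_args const& args);

} // namespace FLASH_NAMESPACE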
280 changes: 280 additions & 0 deletions flash-attn2/flash_attn_xpu/flash_api.cpp
@@ -2,6 +2,8 @@
#include <c10/xpu/XPUStream.h>
#include <cute/util/compat/device.hpp>
#include <ATen/xpu/XPUGeneratorImpl.h>
#include <sycl/sycl.hpp>
#include <limits>

#include "src/fmha_fwd.hpp"
#include "src/fmha_bwd.hpp"
@@ -466,6 +468,226 @@ mha_varlen_fwd(
at::Tensor rng_state;
return {out, softmax_lse, S_dmask, rng_state};
}

std::vector<at::Tensor>
mha_fwd_kvcache(
at::Tensor &q,
const at::Tensor &kcache,
const at::Tensor &vcache,
std::optional<const at::Tensor> &k_,
std::optional<const at::Tensor> &v_,
std::optional<const at::Tensor> &seqlens_k_,
std::optional<const at::Tensor> &rotary_cos_,
std::optional<const at::Tensor> &rotary_sin_,
std::optional<const at::Tensor> &cache_batch_idx_,
std::optional<const at::Tensor> &leftpad_k_,
std::optional<at::Tensor> &block_table_,
std::optional<at::Tensor> &alibi_slopes_,
std::optional<at::Tensor> &out_,
const float softmax_scale,
bool is_causal,
int window_size_left,
int window_size_right,
const float softcap,
bool is_rotary_interleaved,
int num_splits) {
auto device_idx = q.device().index();
compat::select_device(device_idx);

TORCH_CHECK(!alibi_slopes_.has_value(),
"FlashAttention KVCache on XPU does not support alibi_slopes.");

auto q_dtype = q.dtype();
TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
"FlashAttention KVCache only supports fp16 and bf16 data types");
TORCH_CHECK(kcache.dtype() == q_dtype, "query and key cache must have the same dtype");
TORCH_CHECK(vcache.dtype() == q_dtype, "query and value cache must have the same dtype");

CHECK_DEVICE(q); CHECK_DEVICE(kcache); CHECK_DEVICE(vcache);
TORCH_CHECK(q.stride(-1) == 1, "Query must have contiguous last dimension");
TORCH_CHECK(kcache.stride(-1) == 1, "Key cache must have contiguous last dimension");
TORCH_CHECK(vcache.stride(-1) == 1, "Value cache must have contiguous last dimension");

const bool paged_KV = block_table_.has_value();
at::Tensor block_table;
if (paged_KV) {
TORCH_CHECK(!cache_batch_idx_.has_value(), "Paged KVcache does not support cache_batch_idx");
block_table = block_table_.value();
CHECK_DEVICE(block_table);
TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
}

const auto sizes = q.sizes();
const int batch_size = sizes[0];
int seqlen_q = sizes[1];
int num_heads = sizes[2];
const int head_size_og = sizes[3];

const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);
const int page_block_size = !paged_KV ? 1 : kcache.size(1);
const int seqlen_k = !paged_KV ? kcache.size(1) : max_num_blocks_per_seq * page_block_size;
const int num_heads_k = kcache.size(2);

TORCH_CHECK(batch_size > 0, "batch size must be positive");
TORCH_CHECK(head_size_og <= 256 || head_size_og == 512,
"FlashAttention KVCache only supports head dimension up to 256 or exactly 512");
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
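// Paged-KV addressing (illustrative): with a block table, logical token t of
// batch b resides at physical row
// block_table[b][t / page_block_size] * page_block_size + (t % page_block_size),
// so seqlen_k computed above is the table's per-sequence capacity; the live
// length still comes from seqlens_k.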

if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; }
if (is_causal) { window_size_right = 0; }
if (window_size_left >= seqlen_k) { window_size_left = -1; }
if (window_size_right >= seqlen_k) { window_size_right = -1; }
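// Window convention: -1 means unbounded on that side, and causal masking is
// expressed as window_size_right == 0; a window that already covers the whole
// cache is normalized to -1 so the kernel can skip local masking.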

const int head_size_padded = round_multiple(head_size_og, 32);
const bool needs_padding = (head_size_og != head_size_padded);
const int pad_size = head_size_padded - head_size_og;
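// e.g. head_size_og = 100 pads to head_size_padded = 128 (pad_size = 28);
// a multiple of 32 such as 96 passes through with needs_padding == false.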

auto maybe_pad = [&](const at::Tensor& t) -> at::Tensor {
return needs_padding
? torch::nn::functional::pad(t, torch::nn::functional::PadFuncOptions({0, pad_size}))
: t;
};

at::Tensor q_padded = maybe_pad(ensure_contiguous(q));
at::Tensor kcache_padded = maybe_pad(ensure_contiguous(kcache));
at::Tensor vcache_padded = maybe_pad(ensure_contiguous(vcache));

at::Tensor out;
if (out_.has_value()) {
out = out_.value();
TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
CHECK_DEVICE(out);
if (needs_padding) { out = maybe_pad(out); }
} else {
out = torch::zeros_like(q_padded);
}

auto opts = q.options();
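// Initialize the LSE to -inf: rows that attend to no keys (e.g. fully
// masked) then read back the mathematically correct log-sum-exp of zero.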
auto softmax_lse = torch::full({batch_size, num_heads, seqlen_q},
-std::numeric_limits<float>::infinity(),
opts.dtype(at::kFloat));

// Handle new K/V
at::Tensor k_padded, v_padded;
int seqlen_new = 0;
if (k_.has_value()) {
TORCH_CHECK(v_.has_value(), "If key is supplied, value must also be passed in");
TORCH_CHECK(seqlens_k_.has_value(), "If key is supplied, seqlens_k must also be passed in");
auto k = k_.value();
auto v = v_.value();
TORCH_CHECK(k.dtype() == q_dtype && v.dtype() == q_dtype,
"New key/value must have the same dtype as query");
CHECK_DEVICE(k); CHECK_DEVICE(v);
TORCH_CHECK(k.stride(-1) == 1 && v.stride(-1) == 1,
"New key/value must have contiguous last dimension");
seqlen_new = k.size(1);
k_padded = maybe_pad(ensure_contiguous(k));
v_padded = maybe_pad(ensure_contiguous(v));
}

at::Tensor seqlens_k;
if (seqlens_k_.has_value()) {
seqlens_k = seqlens_k_.value();
TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32");
CHECK_DEVICE(seqlens_k);
}

at::Tensor rotary_cos, rotary_sin;
int rotary_dim = 0;
const bool has_rotary = rotary_cos_.has_value();
if (has_rotary) {
TORCH_CHECK(k_.has_value(), "If rotary cos/sin are provided, new key/value must also be provided");
TORCH_CHECK(rotary_sin_.has_value(), "If rotary cos is provided, rotary sin must also be provided");
rotary_cos = ensure_contiguous(rotary_cos_.value());
rotary_sin = ensure_contiguous(rotary_sin_.value());
CHECK_DEVICE(rotary_cos); CHECK_DEVICE(rotary_sin);
rotary_dim = rotary_cos.size(1) * 2;
TORCH_CHECK(rotary_dim <= head_size_og, "rotary_dim must be <= headdim");
TORCH_CHECK(rotary_dim % 16 == 0, "Only rotary dimensions divisible by 16 are currently supported");
TORCH_CHECK(rotary_cos.scalar_type() == q_dtype && rotary_sin.scalar_type() == q_dtype,
"rotary cos/sin must have the same dtype as query");
}
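// rotary_cos/rotary_sin are laid out as (seqlen_ro, rotary_dim / 2), so
// size(1) * 2 recovers rotary_dim.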

at::Tensor cache_batch_idx;
if (cache_batch_idx_.has_value()) {
cache_batch_idx = cache_batch_idx_.value();
CHECK_DEVICE(cache_batch_idx);
TORCH_CHECK(cache_batch_idx.scalar_type() == torch::kInt32,
"cache_batch_idx must have dtype int32");
}

at::Tensor leftpad_k;
if (leftpad_k_.has_value()) {
TORCH_CHECK(!paged_KV, "Paged KV and leftpad_k are not supported together");
leftpad_k = leftpad_k_.value();
CHECK_DEVICE(leftpad_k);
TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32");
}

bool fuse_knew = k_.has_value() && seqlen_new > 0;

// Dispatch to kernel. Paged caches are now passed natively (block_table
// routed straight through to the kernel, no host gather).
auto queue = c10::xpu::getCurrentXPUStream(device_idx).queue();
const bool is_local = (window_size_left >= 0);

std::optional<at::Tensor> cache_batch_idx_opt;
if (cache_batch_idx_.has_value()) {
cache_batch_idx_opt = cache_batch_idx;
}

std::optional<at::Tensor> leftpad_k_opt;
if (leftpad_k_.has_value()) {
leftpad_k_opt = leftpad_k;
}

// For paths where new KV is appended in-kernel, pass knew/vnew through.
std::optional<at::Tensor> knew_opt, vnew_opt;
if (fuse_knew) {
knew_opt = k_padded;
vnew_opt = v_padded;
}

std::optional<at::Tensor> block_table_opt;
if (paged_KV) {
block_table_opt = block_table;
}

std::optional<at::Tensor> rotary_cos_opt, rotary_sin_opt;
if (fuse_knew && has_rotary) {
rotary_cos_opt = rotary_cos;
rotary_sin_opt = rotary_sin;
}

cutlass_fmha_fwd_kvcache_impl(
queue,
q_padded, kcache_padded, vcache_padded,
out, softmax_lse,
seqlens_k, cache_batch_idx_opt, leftpad_k_opt,
knew_opt, vnew_opt,
block_table_opt, rotary_cos_opt, rotary_sin_opt,
fuse_knew ? rotary_dim : 0, is_rotary_interleaved, seqlen_k,
softmax_scale, window_size_left, window_size_right,
is_causal, is_local);

// Update seqlens_k in place once the kernel has consumed the pre-append
// lengths, so the caller's tensor reflects the newly appended tokens (an
// out-of-place add would only rebind the local and be discarded).
if (fuse_knew) {
seqlens_k.add_(seqlen_new);
}

if (needs_padding) {
out = out.index({torch::indexing::Slice(), torch::indexing::Slice(),
torch::indexing::Slice(), torch::indexing::Slice(0, head_size_og)})
.contiguous();
if (out_.has_value()) { out_.value().copy_(out); }
if (fuse_knew) {
// The fused kernel updates the padded cache buffer; publish valid dims back to the user cache.
kcache.copy_(kcache_padded.index({torch::indexing::Slice(), torch::indexing::Slice(),
torch::indexing::Slice(), torch::indexing::Slice(0, head_size_og)}));
vcache.copy_(vcache_padded.index({torch::indexing::Slice(), torch::indexing::Slice(),
torch::indexing::Slice(), torch::indexing::Slice(0, head_size_og)}));
}
}

return {out, softmax_lse};
}

} // namespace FLASH_NAMESPACE

// std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
@@ -633,4 +855,62 @@ mha_bwd(const torch::Tensor &dout,
gen_,
rng_opt
);
}

std::vector<torch::Tensor>
mha_fwd_kvcache(
const torch::Tensor &q,
const torch::Tensor &kcache,
const torch::Tensor &vcache,
const c10::optional<torch::Tensor> &k_,
const c10::optional<torch::Tensor> &v_,
const c10::optional<torch::Tensor> &seqlens_k_,
const c10::optional<torch::Tensor> &rotary_cos_,
const c10::optional<torch::Tensor> &rotary_sin_,
const c10::optional<torch::Tensor> &cache_batch_idx_,
const c10::optional<torch::Tensor> &leftpad_k_,
const c10::optional<torch::Tensor> &block_table_,
const c10::optional<torch::Tensor> &alibi_slopes_,
const c10::optional<torch::Tensor> &out_,
const double softmax_scale,
bool is_causal,
const int64_t window_size_left,
const int64_t window_size_right,
const double softcap,
bool is_rotary_interleaved,
const int64_t num_splits) {
// Convert c10::optional -> std::optional for the internal API
auto to_std_opt = [](const c10::optional<torch::Tensor>& opt) -> std::optional<at::Tensor> {
return opt.has_value() ? std::optional<at::Tensor>(opt.value()) : std::nullopt;
};
auto to_std_opt_const = [](const c10::optional<torch::Tensor>& opt) -> std::optional<const at::Tensor> {
return opt.has_value() ? std::optional<const at::Tensor>(opt.value()) : std::nullopt;
};

at::Tensor q_mut = q;
auto k_opt = to_std_opt_const(k_);
auto v_opt = to_std_opt_const(v_);
auto seqlens_opt = to_std_opt_const(seqlens_k_);
auto rotary_cos_opt = to_std_opt_const(rotary_cos_);
auto rotary_sin_opt = to_std_opt_const(rotary_sin_);
auto cache_batch_idx_opt = to_std_opt_const(cache_batch_idx_);
auto leftpad_k_opt = to_std_opt_const(leftpad_k_);
auto block_table_opt = to_std_opt(block_table_);
auto alibi_opt = to_std_opt(alibi_slopes_);
auto out_opt = to_std_opt(out_);

return FLASH_NAMESPACE::mha_fwd_kvcache(
q_mut, kcache, vcache,
k_opt, v_opt, seqlens_opt,
rotary_cos_opt, rotary_sin_opt,
cache_batch_idx_opt, leftpad_k_opt,
block_table_opt, alibi_opt, out_opt,
static_cast<float>(softmax_scale),
is_causal,
static_cast<int>(window_size_left),
static_cast<int>(window_size_right),
static_cast<float>(softcap),
is_rotary_interleaved,
static_cast<int>(num_splits)
);
}
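The wrapper above maps one-to-one onto FLASH_NAMESPACE::mha_fwd_kvcache. For reference, a hypothetical call-site sketch of a single-token decode step against a paged cache; the shapes follow the layout the checks above imply (kcache/vcache as (num_blocks, page_block_size, num_heads_k, head_size)), and the literal values are illustrative only:

// Hypothetical usage sketch -- single-token decode against a paged KV cache.
const int batch = 2, heads = 8, heads_k = 8, hdim = 128;
const int num_blocks = 16, page = 64, pages_per_seq = 8;
auto opts = torch::dtype(torch::kFloat16).device(torch::kXPU);

auto q      = torch::randn({batch, /*seqlen_q=*/1, heads, hdim}, opts);
auto kcache = torch::zeros({num_blocks, page, heads_k, hdim}, opts);
auto vcache = torch::zeros({num_blocks, page, heads_k, hdim}, opts);
// 37 tokens already resident per sequence; capacity is 8 * 64 = 512.
auto seqlens_k   = torch::full({batch}, 37, opts.dtype(torch::kInt32));
auto block_table = torch::arange(batch * pages_per_seq, opts.dtype(torch::kInt32))
                       .reshape({batch, pages_per_seq});

auto outs = mha_fwd_kvcache(
    q, kcache, vcache,
    /*k_=*/c10::nullopt, /*v_=*/c10::nullopt, seqlens_k,
    /*rotary_cos_=*/c10::nullopt, /*rotary_sin_=*/c10::nullopt,
    /*cache_batch_idx_=*/c10::nullopt, /*leftpad_k_=*/c10::nullopt,
    block_table, /*alibi_slopes_=*/c10::nullopt, /*out_=*/c10::nullopt,
    /*softmax_scale=*/1.0 / std::sqrt(static_cast<double>(hdim)),
    /*is_causal=*/false, /*window_size_left=*/-1, /*window_size_right=*/-1,
    /*softcap=*/0.0, /*is_rotary_interleaved=*/false, /*num_splits=*/0);
// outs[0]: output, (batch, 1, heads, hdim); outs[1]: softmax LSE (fp32).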
42 changes: 42 additions & 0 deletions flash-attn2/flash_attn_xpu/src/collective/fmha_fwd_common.hpp
@@ -22,6 +22,48 @@ namespace cutlass::fmha::collective {

using namespace cute;

template <typename Element, typename RotaryElement>
CUTLASS_DEVICE Element apply_rotary_scalar(
Element x,
Element x_pair,
const RotaryElement* cos,
const RotaryElement* sin,
int position,
int dim,
int rotary_dim,
bool interleaved) {
if (rotary_dim == 0 || dim >= rotary_dim) {
return x;
}

int half_rotary = rotary_dim / 2;
int cos_sin_idx = interleaved ? dim / 2
: (dim < half_rotary ? dim : dim - half_rotary);
bool is_second = interleaved ? (dim % 2) : (dim >= half_rotary);

float x_f = static_cast<float>(x);
float x_pair_f = static_cast<float>(x_pair);
float c = static_cast<float>(cos[position * half_rotary + cos_sin_idx]);
float s = static_cast<float>(sin[position * half_rotary + cos_sin_idx]);
float rotated = is_second ? x_pair_f * s + x_f * c
: x_f * c - x_pair_f * s;
return static_cast<Element>(rotated);
}
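// Standard RoPE rotation for a pair (x1, x2) at angle theta:
// y1 = x1 * cos(theta) - x2 * sin(theta)
// y2 = x1 * sin(theta) + x2 * cos(theta)
// Each lane computes only its own output from its partner value x_pair;
// is_second selects which of the two formulas applies.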

CUTLASS_DEVICE int rotary_pair_dim(
int dim,
int rotary_dim,
bool interleaved) {
if (dim >= rotary_dim) {
return dim;
}
if (interleaved) {
return dim ^ 1;
}
int half_rotary = rotary_dim / 2;
return dim < half_rotary ? dim + half_rotary : dim - half_rotary;
}
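// Illustrative pairings for rotary_dim = 8:
// non-interleaved (NeoX-style): (0,4) (1,5) (2,6) (3,7) -> rotary_pair_dim(1, 8, false) == 5
// interleaved (GPT-J-style): (0,1) (2,3) (4,5) (6,7) -> rotary_pair_dim(2, 8, true) == 3
// Dims at or beyond rotary_dim map to themselves and bypass rotation.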

/////////////////////////////////////////////////////////////////////////////////////////////////
//
// FMHAFwdMainloopTraits: common type aliases derived from TiledMMA / VTiles.