diff --git a/flash-attn2/build.toml b/flash-attn2/build.toml index 89a1e42c..58647e3e 100644 --- a/flash-attn2/build.toml +++ b/flash-attn2/build.toml @@ -1,6 +1,6 @@ [general] name = "flash-attn2" -version = 1 +version = 3 license = "BSD-3-Clause" backends = [ "cpu", @@ -12,8 +12,10 @@ backends = [ repo-id = "kernels-community/flash-attn2" [torch] +stable-abi = "2.10" src = [ "torch-ext/torch_binding.cpp", + "torch-ext/torch_binding_stable.cpp", "torch-ext/torch_binding.h", ] @@ -180,6 +182,7 @@ depends = [ src = [ "flash_attn/flash_api.cpp", "flash_attn/src/philox_unpack.cuh", + "flash_attn/src/cuda_check.h", "flash_attn/src/namespace_config.h", "flash_attn/src/hardware_info.h", "flash_attn/src/flash.h", diff --git a/flash-attn2/flash_attn/flash_api.cpp b/flash-attn2/flash_attn/flash_api.cpp index 0b618bf4..5cc5f924 100644 --- a/flash-attn2/flash_attn/flash_api.cpp +++ b/flash-attn2/flash_attn/flash_api.cpp @@ -2,13 +2,13 @@ * Copyright (c) 2024, Tri Dao. ******************************************************************************/ -// Include these 2 headers instead of torch/extension.h since we don't need all of the torch headers. -// #include -#include -#include -#include -#include // For at::Generator and at::PhiloxCudaState -#include "src/philox_unpack.cuh" // For at::cuda::philox::unpack +// Flash-Attention 2 CUDA bindings, ported to the Torch stable ABI. Dropout is +// unsupported here: the CUDA RNG generator is not exposed by the stable ABI. + +// CUDA vector types (int4, float4, ...) must be visible before cutlass headers, +// which use them at namespace scope. flash_api.cpp is host-compiled (g++), so +// nvcc's implicit is not available here. +#include #include @@ -17,9 +17,46 @@ #include "src/flash.h" #include "src/static_switch.h" -#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA") -#define CHECK_SHAPE(x, ...) 
TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") -#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#include +#include +#include +#include +#include + +// Declare the CUDA stream function that's behind #ifdef USE_CUDA in shim.h +extern "C" AOTITorchError aoti_torch_get_current_cuda_stream(int32_t device_index, void** ret_stream); + +#include +#include + +#include +#include +#include +#include +#include +#include + +using torch::stable::Tensor; +namespace tsa = torch::stable::accelerator; + +namespace { +inline tsa::DeviceGuard make_device_guard(const Tensor &t) { + return tsa::DeviceGuard(static_cast(t.get_device())); +} +} // namespace + +#define CHECK_DEVICE(x) STD_TORCH_CHECK(x.is_cuda(), #x " must be on CUDA") +#define CHECK_SHAPE(x, ...) \ + do { \ + auto __expected = std::vector{__VA_ARGS__}; \ + STD_TORCH_CHECK(x.dim() == static_cast(__expected.size()), \ + #x " must have ", __expected.size(), " dimensions, got ", x.dim()); \ + for (size_t __i = 0; __i < __expected.size(); ++__i) { \ + STD_TORCH_CHECK(x.size(__i) == __expected[__i], \ + #x " has wrong shape at dim ", __i); \ + } \ + } while (0) +#define CHECK_CONTIGUOUS(x) STD_TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") namespace FLASH_NAMESPACE { @@ -35,10 +72,10 @@ void set_params_fprop(Flash_fwd_params ¶ms, const size_t d, const size_t d_rounded, // device pointers - const at::Tensor q, - const at::Tensor k, - const at::Tensor v, - at::Tensor out, + const Tensor q, + const Tensor k, + const Tensor v, + Tensor out, void *cu_seqlens_q_d, void *cu_seqlens_k_d, void *seqused_k, @@ -55,7 +92,7 @@ void set_params_fprop(Flash_fwd_params ¶ms, // Reset the parameters params = {}; - params.is_bf16 = q.dtype() == torch::kBFloat16; + params.is_bf16 = q.scalar_type() == torch::headeronly::ScalarType::BFloat16; // Set the pointers and strides. 
params.q_ptr = q.data_ptr(); @@ -107,7 +144,7 @@ void set_params_fprop(Flash_fwd_params ¶ms, // Set the different scale values. #ifdef FLASHATTENTION_DISABLE_SOFTCAP - TORCH_CHECK(softcap <= 0.0, "This flash attention build does not support softcap."); + STD_TORCH_CHECK(softcap <= 0.0, "This flash attention build does not support softcap."); #endif if (softcap > 0.0) { params.softcap = softmax_scale / softcap; @@ -129,9 +166,9 @@ void set_params_fprop(Flash_fwd_params ¶ms, params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0)); params.rp_dropout = 1.f / params.p_dropout; params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax; - TORCH_CHECK(p_dropout < 1.f); + STD_TORCH_CHECK(p_dropout < 1.f); #ifdef FLASHATTENTION_DISABLE_DROPOUT - TORCH_CHECK(p_dropout == 0.0f, "This flash attention build does not support dropout."); + STD_TORCH_CHECK(p_dropout == 0.0f, "This flash attention build does not support dropout."); #endif // Causal is the special case where window_size_right == 0 and window_size_left < 0. 
@@ -144,14 +181,14 @@ void set_params_fprop(Flash_fwd_params ¶ms, params.window_size_right = window_size_right; #ifdef FLASHATTENTION_DISABLE_LOCAL - TORCH_CHECK(params.is_causal || (window_size_left < 0 && window_size_right < 0), + STD_TORCH_CHECK(params.is_causal || (window_size_left < 0 && window_size_right < 0), "This flash attention build does not support local attention."); #endif params.is_seqlens_k_cumulative = true; #ifdef FLASHATTENTION_DISABLE_UNEVEN_K - TORCH_CHECK(d == d_rounded, "This flash attention build does not support headdim not being a multiple of 32."); + STD_TORCH_CHECK(d == d_rounded, "This flash attention build does not support headdim not being a multiple of 32."); #endif params.unpadded_lse = unpadded_lse; @@ -170,14 +207,14 @@ void set_params_dgrad(Flash_bwd_params ¶ms, const size_t d, const size_t d_rounded, // device pointers - const at::Tensor q, - const at::Tensor k, - const at::Tensor v, - const at::Tensor out, - const at::Tensor dout, - at::Tensor dq, - at::Tensor dk, - at::Tensor dv, + const Tensor q, + const Tensor k, + const Tensor v, + const Tensor out, + const Tensor dout, + Tensor dq, + Tensor dk, + Tensor dv, void *cu_seqlens_q_d, void *cu_seqlens_k_d, void *dq_accum_d, @@ -296,10 +333,10 @@ inline int num_splits_heuristic(int batch_nheads_mblocks, int num_SMs, int num_n return 1; } -std::tuple set_params_splitkv(Flash_fwd_params ¶ms, const int batch_size, +std::tuple set_params_splitkv(Flash_fwd_params ¶ms, const int batch_size, const int num_heads, const int head_size, const int max_seqlen_k, const int max_seqlen_q, const int head_size_rounded, const float p_dropout, - const int num_splits, const int num_sm, struct c10::TensorOptions opts) { + const int num_splits, const int num_sm, const Tensor &splitkv_ref) { // This needs to match with run_mha_fwd_splitkv_dispatch const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 
128 : 64); @@ -308,8 +345,8 @@ std::tuple set_params_splitkv(Flash_fwd_params ¶ms, // In any case we don't expect seqlen_q to be larger than 64 for inference. const int num_m_blocks = (max_seqlen_q + 64 - 1) / 64; params.num_splits = num_splits; - at::Tensor softmax_lse_accum; - at::Tensor out_accum; + Tensor softmax_lse_accum; + Tensor out_accum; if (p_dropout == 0.0f) { // SplitKV is not implemented for dropout if (num_splits < 1) { @@ -317,28 +354,30 @@ std::tuple set_params_splitkv(Flash_fwd_params ¶ms, params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, num_sm * 2, num_n_blocks, 128); } if (params.num_splits > 1) { - softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat)); - out_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q, head_size_rounded}, opts.dtype(at::kFloat)); + softmax_lse_accum = torch::stable::new_empty(splitkv_ref, {params.num_splits, batch_size, num_heads, max_seqlen_q}, torch::headeronly::ScalarType::Float); + out_accum = torch::stable::new_empty(splitkv_ref, {params.num_splits, batch_size, num_heads, max_seqlen_q, head_size_rounded}, torch::headeronly::ScalarType::Float); params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr(); params.oaccum_ptr = out_accum.data_ptr(); } - TORCH_CHECK(params.num_splits <= 128, "num_splits > 128 not supported"); + STD_TORCH_CHECK(params.num_splits <= 128, "num_splits > 128 not supported"); } return std::make_tuple(softmax_lse_accum, out_accum); } -void set_params_alibi(Flash_fwd_params ¶ms, std::optional &alibi_slopes_, int batch_size, int num_heads){ +void set_params_alibi(Flash_fwd_params ¶ms, std::optional &alibi_slopes_, int batch_size, int num_heads){ #ifdef FLASHATTENTION_DISABLE_ALIBI - TORCH_CHECK(!alibi_slopes_.has_value(), "This flash attention build does not support alibi."); + STD_TORCH_CHECK(!alibi_slopes_.has_value(), "This flash attention build does not support alibi."); 
params.alibi_slopes_ptr = nullptr; #else if (alibi_slopes_.has_value()) { auto alibi_slopes = alibi_slopes_.value(); - TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32"); + STD_TORCH_CHECK(alibi_slopes.scalar_type() == torch::headeronly::ScalarType::Float, "ALiBi slopes must have dtype fp32"); CHECK_DEVICE(alibi_slopes); - TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension"); - TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({num_heads}) || alibi_slopes.sizes() == torch::IntArrayRef({batch_size, num_heads})); + STD_TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension"); + STD_TORCH_CHECK((alibi_slopes.dim() == 1 && alibi_slopes.size(0) == num_heads) || + (alibi_slopes.dim() == 2 && alibi_slopes.size(0) == batch_size && alibi_slopes.size(1) == num_heads), + "ALiBi slopes must have shape (num_heads,) or (batch_size, num_heads)"); params.alibi_slopes_ptr = alibi_slopes.data_ptr(); params.alibi_slopes_batch_stride = alibi_slopes.dim() == 2 ? 
alibi_slopes.stride(0) : 0; } else { @@ -347,41 +386,42 @@ void set_params_alibi(Flash_fwd_params ¶ms, std::optional &alibi #endif } -std::vector -mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8) - const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8) - const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8) - std::optional &out_, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8) - std::optional &alibi_slopes_, // num_heads or batch_size x num_heads +std::vector +mha_fwd(Tensor &q, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8) + const Tensor &k, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8) + const Tensor &v, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8) + std::optional &out_, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8) + std::optional &alibi_slopes_, // num_heads or batch_size x num_heads const float p_dropout, const float softmax_scale, bool is_causal, int window_size_left, int window_size_right, const float softcap, - const bool return_softmax, - std::optional gen_) { + const bool return_softmax) { // Otherwise the kernel will be launched from cuda:0 device - at::cuda::CUDAGuard device_guard{q.device()}; + auto device_guard = make_device_guard(q); auto [cc_major, cc_minor] = get_compute_capability(get_current_device()); bool is_sm8x_min = cc_major >= 8; - TORCH_CHECK(is_sm8x_min, "FlashAttention only supports Ampere GPUs or newer."); + STD_TORCH_CHECK(is_sm8x_min, "FlashAttention only supports Ampere GPUs or newer."); - auto q_dtype = q.dtype(); - TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, + auto q_dtype = q.scalar_type(); + STD_TORCH_CHECK(q_dtype == torch::headeronly::ScalarType::Half || q_dtype == torch::headeronly::ScalarType::BFloat16, "FlashAttention only support fp16 and bf16 data type"); - 
TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); - TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); + STD_TORCH_CHECK(k.scalar_type() == q_dtype, "query and key must have the same dtype"); + STD_TORCH_CHECK(v.scalar_type() == q_dtype, "query and value must have the same dtype"); CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); - TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + STD_TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + STD_TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + STD_TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - const auto sizes = q.sizes(); + std::vector sizes; + sizes.reserve(q.dim()); + for (int64_t __d = 0; __d < q.dim(); ++__d) sizes.push_back(q.size(__d)); const int batch_size = sizes[0]; int seqlen_q = sizes[1]; @@ -389,12 +429,12 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x round_mult const int head_size = sizes[3]; const int seqlen_k = k.size(1); const int num_heads_k = k.size(2); - TORCH_CHECK(batch_size > 0, "batch size must be positive"); - TORCH_CHECK(head_size <= 256, "FlashAttention forward only supports head dimension at most 256"); - TORCH_CHECK(head_size % 8 == 0, "query, key, value, and out_ must have a head_size that is a multiple of 8"); - TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + STD_TORCH_CHECK(batch_size > 0, "batch size must be positive"); + STD_TORCH_CHECK(head_size <= 256, "FlashAttention forward only supports head dimension at most 256"); + STD_TORCH_CHECK(head_size % 8 == 0, "query, key, value, and out_ must have a head_size that is a multiple of 8"); + 
STD_TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); - if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); } + if (softcap > 0.f) { STD_TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); } if (window_size_left >= seqlen_k) { window_size_left = -1; } if (window_size_right >= seqlen_k) { window_size_right = -1; } @@ -408,7 +448,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x round_mult const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size % 8 == 0 && !alibi_slopes_.has_value(); const int ngroups = num_heads / num_heads_k; if (seqlenq_ngroups_swapped) { - q = q.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2); + q = torch::stable::transpose(torch::stable::reshape(q, {batch_size, num_heads_k, ngroups, head_size}), 1, 2); seqlen_q = ngroups; num_heads = num_heads_k; } @@ -417,18 +457,18 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x round_mult CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size); CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size); - at::Tensor out; + Tensor out; if (out_.has_value()) { out = out_.value(); - TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs"); + STD_TORCH_CHECK(out.scalar_type() == q_dtype, "Output must have the same dtype as inputs"); CHECK_DEVICE(out); - TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); + STD_TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); CHECK_SHAPE(out, batch_size, sizes[1], sizes[2], head_size); if (seqlenq_ngroups_swapped) { - out = out.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2); + out = torch::stable::transpose(torch::stable::reshape(out, {batch_size, num_heads_k, 
ngroups, head_size}), 1, 2); } } else { - out = torch::empty_like(q); + out = torch::stable::empty_like(q); } auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; @@ -436,17 +476,16 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x round_mult const int seqlen_q_rounded = round_multiple(seqlen_q, 128); const int seqlen_k_rounded = round_multiple(seqlen_k, 128); - auto opts = q.options(); - auto softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat)); - at::Tensor p; + auto softmax_lse = torch::stable::new_empty(q, {batch_size, num_heads, seqlen_q}, torch::headeronly::ScalarType::Float); + Tensor p; // Only return softmax if there's dropout to reduce compilation time if (return_softmax) { - TORCH_CHECK(p_dropout > 0.0f, "return_softmax is only supported when p_dropout > 0.0"); - p = torch::empty({ batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded }, opts); + STD_TORCH_CHECK(p_dropout > 0.0f, "return_softmax is only supported when p_dropout > 0.0"); + p = torch::stable::new_empty(q, { batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded }); } else { - p = torch::empty({ 0 }, opts); + p = torch::stable::new_empty(q, { 0 }); } Flash_fwd_params params; @@ -470,58 +509,55 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x round_mult ); // Keep references to these tensors to extend their lifetime - at::Tensor softmax_lse_accum, out_accum; + Tensor softmax_lse_accum, out_accum; std::tie(softmax_lse_accum, out_accum) = set_params_splitkv( params, batch_size, num_heads, head_size, seqlen_k, seqlen_q, - head_size_rounded, p_dropout, /*num_splits*/ 0, get_num_sm(get_current_device()), opts); + head_size_rounded, p_dropout, /*num_splits*/ 0, get_num_sm(get_current_device()), q); // number of times random will be generated per thread, to offset philox counter in thc random // state // We use a custom RNG that increases the offset by batch_size * nheads * 32. 
int64_t counter_offset = params.b * params.h * 32; - auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); - auto rng_state = torch::empty({2}, options.dtype(torch::kInt64)); + auto rng_state = torch::stable::new_empty(q, {2}, torch::headeronly::ScalarType::Long); // Forward kernel will populate memory with the seed and offset. params.rng_state = reinterpret_cast(rng_state.data_ptr()); - if (p_dropout > 0.0) { - auto gen = at::get_generator_or_default( - gen_, at::cuda::detail::getDefaultCUDAGenerator()); - // See Note [Acquire lock when using random generators] - std::lock_guard lock(gen->mutex_); - params.philox_args = gen->philox_cuda_state(counter_offset); - } + STD_TORCH_CHECK(p_dropout == 0.0, "flash-attn2 stable-ABI build does not support dropout (p_dropout > 0)."); + (void)counter_offset; set_params_alibi(params, alibi_slopes_, batch_size, num_heads); if (seqlen_k > 0) { - auto stream = at::cuda::getCurrentCUDAStream().stream(); + auto stream_device_idx = torch::stable::accelerator::getCurrentDeviceIndex(); + void* stream_ptr = nullptr; + TORCH_ERROR_CODE_CHECK(aoti_torch_get_current_cuda_stream(stream_device_idx, &stream_ptr)); + cudaStream_t stream = static_cast(stream_ptr); run_mha_fwd(params, stream); } else { // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0. 
- out.zero_(); - softmax_lse.fill_(std::numeric_limits::infinity()); + torch::stable::fill_(out, 0); + torch::stable::fill_(softmax_lse, std::numeric_limits::infinity()); } if (seqlenq_ngroups_swapped) { - out = out.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size}); - q = q.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size}); - softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1}); + out = torch::stable::reshape(torch::stable::transpose(out, 1, 2), {batch_size, 1, num_heads_k * seqlen_q, head_size}); + q = torch::stable::reshape(torch::stable::transpose(q, 1, 2), {batch_size, 1, num_heads_k * seqlen_q, head_size}); + softmax_lse = torch::stable::reshape(softmax_lse, {batch_size, num_heads_k * seqlen_q, 1}); } return {out, softmax_lse, p, rng_state}; } -std::vector -mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i - const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. - const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. - std::optional &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i - const at::Tensor &cu_seqlens_q, // b+1 - const at::Tensor &cu_seqlens_k, // b+1 - std::optional &seqused_k, // b. If given, only this many elements of each batch element's keys are used. 
- std::optional &leftpad_k_, // batch_size - std::optional &block_table_, // batch_size x max_num_blocks_per_seq - std::optional &alibi_slopes_, // num_heads or b x num_heads +std::vector +mha_varlen_fwd(Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. + const Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. + std::optional &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const Tensor &cu_seqlens_q, // b+1 + const Tensor &cu_seqlens_k, // b+1 + std::optional &seqused_k, // b. If given, only this many elements of each batch element's keys are used. + std::optional &leftpad_k_, // batch_size + std::optional &block_table_, // batch_size x max_num_blocks_per_seq + std::optional &alibi_slopes_, // num_heads or b x num_heads int max_seqlen_q, const int max_seqlen_k, const float p_dropout, @@ -531,56 +567,57 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s int window_size_left, int window_size_right, const float softcap, - const bool return_softmax, - std::optional gen_) { + const bool return_softmax) { // Otherwise the kernel will be launched from cuda:0 device - at::cuda::CUDAGuard device_guard{q.device()}; + auto device_guard = make_device_guard(q); auto [cc_major, cc_minor] = get_compute_capability(get_current_device()); bool is_sm8x_min = cc_major >= 8; - TORCH_CHECK(is_sm8x_min, "FlashAttention only supports Ampere GPUs or newer."); + STD_TORCH_CHECK(is_sm8x_min, "FlashAttention only supports Ampere GPUs or newer."); - auto q_dtype = q.dtype(); - TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, + auto q_dtype = q.scalar_type(); + STD_TORCH_CHECK(q_dtype == 
torch::headeronly::ScalarType::Half || q_dtype == torch::headeronly::ScalarType::BFloat16, "FlashAttention only support fp16 and bf16 data type"); - TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); - TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); - TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32"); - TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32"); + STD_TORCH_CHECK(k.scalar_type() == q_dtype, "query and key must have the same dtype"); + STD_TORCH_CHECK(v.scalar_type() == q_dtype, "query and value must have the same dtype"); + STD_TORCH_CHECK(cu_seqlens_q.scalar_type() == torch::headeronly::ScalarType::Int, "cu_seqlens_q must have dtype int32"); + STD_TORCH_CHECK(cu_seqlens_k.scalar_type() == torch::headeronly::ScalarType::Int, "cu_seqlens_k must have dtype int32"); CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); CHECK_DEVICE(cu_seqlens_q); CHECK_DEVICE(cu_seqlens_k); - at::Tensor block_table; + Tensor block_table; const bool paged_KV = block_table_.has_value(); if (paged_KV) { block_table = block_table_.value(); CHECK_DEVICE(block_table); - TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32"); - TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension"); + STD_TORCH_CHECK(block_table.scalar_type() == torch::headeronly::ScalarType::Int, "block_table must have dtype torch.int32"); + STD_TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension"); } - TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + STD_TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + STD_TORCH_CHECK(k.stride(-1) == 
1, "Input tensor must have contiguous last dimension"); + STD_TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); CHECK_CONTIGUOUS(cu_seqlens_q); CHECK_CONTIGUOUS(cu_seqlens_k); - const auto sizes = q.sizes(); + std::vector sizes; + sizes.reserve(q.dim()); + for (int64_t __d = 0; __d < q.dim(); ++__d) sizes.push_back(q.size(__d)); const int batch_size = cu_seqlens_q.numel() - 1; int num_heads = sizes[1]; const int head_size = sizes[2]; const int num_heads_k = paged_KV ? k.size(2) : k.size(1); - if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); } + if (softcap > 0.f) { STD_TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); } const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1); const int num_blocks = !paged_KV ? 0 : k.size(0); const int page_block_size = !paged_KV ? 1 : k.size(1); - TORCH_CHECK(!paged_KV || page_block_size % 256 == 0, "Paged KV cache block size must be divisible by 256"); + STD_TORCH_CHECK(!paged_KV || page_block_size % 256 == 0, "Paged KV cache block size must be divisible by 256"); if (max_seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; } // causal=true is the same as causal=false in this case if (is_causal) { window_size_right = 0; } @@ -592,18 +629,18 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s const int seqlenq_ngroups_swapped = max_seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size % 8 == 0 && !alibi_slopes_.has_value(); const int ngroups = num_heads / num_heads_k; if (seqlenq_ngroups_swapped) { - q = q.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size}); + q = torch::stable::reshape(torch::stable::transpose(torch::stable::reshape(q, {batch_size, num_heads_k, ngroups, head_size}), 1, 2), {batch_size * ngroups, 
num_heads_k, head_size}); max_seqlen_q = ngroups; num_heads = num_heads_k; cu_seqlens_q_d = nullptr; } - const int total_q = q.sizes()[0]; + const int total_q = q.size(0); - TORCH_CHECK(batch_size > 0, "batch size must be positive"); - TORCH_CHECK(head_size <= 256, "FlashAttention forward only supports head dimension at most 256"); - TORCH_CHECK(head_size % 8 == 0, "query, key, value, and out_ must have a head_size that is a multiple of 8"); - TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + STD_TORCH_CHECK(batch_size > 0, "batch size must be positive"); + STD_TORCH_CHECK(head_size <= 256, "FlashAttention forward only supports head dimension at most 256"); + STD_TORCH_CHECK(head_size % 8 == 0, "query, key, value, and out_ must have a head_size that is a multiple of 8"); + STD_TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); if (window_size_left >= max_seqlen_k) { window_size_left = -1; } if (window_size_right >= max_seqlen_k) { window_size_right = -1; } @@ -623,24 +660,24 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s CHECK_SHAPE(cu_seqlens_k, batch_size + 1); if (seqused_k.has_value()){ auto seqused_k_ = seqused_k.value(); - TORCH_CHECK(seqused_k_.dtype() == torch::kInt32, "seqused_k must have dtype int32"); - TORCH_CHECK(seqused_k_.is_cuda(), "seqused_k must be on CUDA device"); - TORCH_CHECK(seqused_k_.is_contiguous(), "seqused_k must be contiguous"); + STD_TORCH_CHECK(seqused_k_.scalar_type() == torch::headeronly::ScalarType::Int, "seqused_k must have dtype int32"); + STD_TORCH_CHECK(seqused_k_.is_cuda(), "seqused_k must be on CUDA device"); + STD_TORCH_CHECK(seqused_k_.is_contiguous(), "seqused_k must be contiguous"); CHECK_SHAPE(seqused_k_, batch_size); } - at::Tensor out; + Tensor out; if (out_.has_value()) { out = out_.value(); - TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same 
dtype as inputs"); + STD_TORCH_CHECK(out.scalar_type() == q_dtype, "Output must have the same dtype as inputs"); CHECK_DEVICE(out); - TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); + STD_TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); CHECK_SHAPE(out, sizes[0], sizes[1], head_size); if (seqlenq_ngroups_swapped) { - out = out.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size}); + out = torch::stable::reshape(torch::stable::transpose(torch::stable::reshape(out, {batch_size, num_heads_k, ngroups, head_size}), 1, 2), {batch_size * ngroups, num_heads_k, head_size}); } } else { - out = torch::empty_like(q); + out = torch::stable::empty_like(q); } auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; @@ -648,22 +685,21 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128); const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128); - auto opts = q.options(); - auto softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat)); - at::Tensor p; + auto softmax_lse = torch::stable::new_empty(q, {num_heads, total_q}, torch::headeronly::ScalarType::Float); + Tensor p; // Only return softmax if there's dropout to reduce compilation time if (return_softmax) { - TORCH_CHECK(p_dropout > 0.0f, "return_softmax is only supported when p_dropout > 0.0"); - p = torch::empty({ batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded }, opts); + STD_TORCH_CHECK(p_dropout > 0.0f, "return_softmax is only supported when p_dropout > 0.0"); + p = torch::stable::new_empty(q, { batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded }); } else { - p = torch::empty({ 0 }, opts); + p = torch::stable::new_empty(q, { 0 }); } if (zero_tensors) { - out.zero_(); - softmax_lse.fill_(-std::numeric_limits::infinity()); - if 
(return_softmax) {p.zero_();} + torch::stable::fill_(out, 0); + torch::stable::fill_(softmax_lse, -std::numeric_limits::infinity()); + if (return_softmax) {torch::stable::fill_(p, 0);} } Flash_fwd_params params; @@ -689,26 +725,26 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s params.total_q = total_q; if (paged_KV) { - params.block_table = block_table.data_ptr(); + params.block_table = static_cast(block_table.data_ptr()); params.block_table_batch_stride = block_table.stride(0); params.k_batch_stride = k.stride(0); params.v_batch_stride = v.stride(0); } params.page_block_size = page_block_size; // Keep references to these tensors to extend their lifetime - at::Tensor softmax_lse_accum, out_accum; + Tensor softmax_lse_accum, out_accum; if (seqlenq_ngroups_swapped) { // Only apply split-k for decoding std::tie(softmax_lse_accum, out_accum) = set_params_splitkv(params, batch_size, num_heads, head_size, max_seqlen_k, max_seqlen_q, head_size_rounded, - p_dropout, /*num_splits*/ 0, get_num_sm(get_current_device()), opts); + p_dropout, /*num_splits*/ 0, get_num_sm(get_current_device()), q); } if (leftpad_k_.has_value()) { auto leftpad_k = leftpad_k_.value(); - TORCH_CHECK(!paged_KV, "We don't support Paged KV and leftpad_k running at the same time yet"); - TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32"); + STD_TORCH_CHECK(!paged_KV, "We don't support Paged KV and leftpad_k running at the same time yet"); + STD_TORCH_CHECK(leftpad_k.scalar_type() == torch::headeronly::ScalarType::Int, "leftpad_k must have dtype int32"); CHECK_DEVICE(leftpad_k); CHECK_CONTIGUOUS(leftpad_k); CHECK_SHAPE(leftpad_k, batch_size); @@ -719,36 +755,33 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s // state // We use a custom RNG that increases the offset by batch_size * nheads * 32. 
int64_t counter_offset = params.b * params.h * 32; - auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); - auto rng_state = torch::empty({2}, options.dtype(torch::kInt64)); + auto rng_state = torch::stable::new_empty(q, {2}, torch::headeronly::ScalarType::Long); // Forward kernel will populate memory with the seed and offset. params.rng_state = reinterpret_cast(rng_state.data_ptr()); - if (p_dropout > 0.0) { - auto gen = at::get_generator_or_default( - gen_, at::cuda::detail::getDefaultCUDAGenerator()); - // See Note [Acquire lock when using random generators] - std::lock_guard lock(gen->mutex_); - params.philox_args = gen->philox_cuda_state(counter_offset); - } + STD_TORCH_CHECK(p_dropout == 0.0, "flash-attn2 stable-ABI build does not support dropout (p_dropout > 0)."); + (void)counter_offset; set_params_alibi(params, alibi_slopes_, batch_size, num_heads); if (max_seqlen_k > 0) { - auto stream = at::cuda::getCurrentCUDAStream().stream(); + auto stream_device_idx = torch::stable::accelerator::getCurrentDeviceIndex(); + void* stream_ptr = nullptr; + TORCH_ERROR_CODE_CHECK(aoti_torch_get_current_cuda_stream(stream_device_idx, &stream_ptr)); + cudaStream_t stream = static_cast(stream_ptr); run_mha_fwd(params, stream, paged_KV); } else { // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0. 
- out.zero_(); - softmax_lse.fill_(std::numeric_limits::infinity()); + torch::stable::fill_(out, 0); + torch::stable::fill_(softmax_lse, std::numeric_limits::infinity()); } if (seqlenq_ngroups_swapped) { - int64_t size_before[] = {batch_size, max_seqlen_q, num_heads_k, head_size}; - int64_t size_after[] = {batch_size, num_heads_k * max_seqlen_q, head_size}; - out = out.reshape(size_before).transpose(1, 2).reshape(size_after); - q = q.reshape(size_before).transpose(1, 2).reshape(size_after); - softmax_lse = softmax_lse.reshape({num_heads * max_seqlen_q, batch_size}); + std::vector size_before = {batch_size, max_seqlen_q, num_heads_k, head_size}; + std::vector size_after = {batch_size, num_heads_k * max_seqlen_q, head_size}; + out = torch::stable::reshape(torch::stable::transpose(torch::stable::reshape(out, size_before), 1, 2), size_after); + q = torch::stable::reshape(torch::stable::transpose(torch::stable::reshape(q, size_before), 1, 2), size_after); + softmax_lse = torch::stable::reshape(softmax_lse, {num_heads * max_seqlen_q, batch_size}); } return {out, softmax_lse, p, rng_state}; @@ -764,17 +797,17 @@ void run_mha_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { }); } -std::vector -mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x multiple_of(head_size_og, 8) - const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size - const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size - const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size - const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size - const at::Tensor &softmax_lse, // b x h x seqlen_q - std::optional &dq_, // batch_size x seqlen_q x num_heads x head_size - std::optional &dk_, // batch_size x seqlen_k x num_heads_k x head_size - std::optional &dv_, // batch_size x seqlen_k x num_heads_k x head_size - std::optional &alibi_slopes_, // num_heads or batch_size x num_heads +std::vector +mha_bwd(const Tensor &dout, // batch_size x 
seqlen_q x num_heads, x multiple_of(head_size_og, 8) + const Tensor &q, // batch_size x seqlen_q x num_heads x head_size + const Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size + const Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size + const Tensor &out, // batch_size x seqlen_q x num_heads x head_size + const Tensor &softmax_lse, // b x h x seqlen_q + std::optional &dq_, // batch_size x seqlen_q x num_heads x head_size + std::optional &dk_, // batch_size x seqlen_k x num_heads_k x head_size + std::optional &dv_, // batch_size x seqlen_k x num_heads_k x head_size + std::optional &alibi_slopes_, // num_heads or batch_size x num_heads const float p_dropout, // probability to drop const float softmax_scale, const bool is_causal, @@ -782,42 +815,46 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x multipl int window_size_right, const float softcap, const bool deterministic, - std::optional gen_, - std::optional &rng_state) { + std::optional &rng_state) { #ifdef FLASHATTENTION_DISABLE_BACKWARD - TORCH_CHECK(false, "This flash attention build does not support backward."); + STD_TORCH_CHECK(false, "This flash attention build does not support backward."); #endif if (is_causal) { window_size_right = 0; } // Otherwise the kernel will be launched from cuda:0 device - at::cuda::CUDAGuard device_guard{q.device()}; + auto device_guard = make_device_guard(q); auto [cc_major, cc_minor] = get_compute_capability(get_current_device()); bool is_sm8x_min = cc_major >= 8; - TORCH_CHECK(is_sm8x_min, "FlashAttention only supports Ampere GPUs or newer."); + STD_TORCH_CHECK(is_sm8x_min, "FlashAttention only supports Ampere GPUs or newer."); bool is_dropout = p_dropout > 0.0; - auto stream = at::cuda::getCurrentCUDAStream().stream(); + auto stream_device_idx = torch::stable::accelerator::getCurrentDeviceIndex(); + void* stream_ptr = nullptr; + TORCH_ERROR_CODE_CHECK(aoti_torch_get_current_cuda_stream(stream_device_idx, &stream_ptr)); + 
cudaStream_t stream = static_cast(stream_ptr); - auto q_dtype = q.dtype(); - TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, + auto q_dtype = q.scalar_type(); + STD_TORCH_CHECK(q_dtype == torch::headeronly::ScalarType::Half || q_dtype == torch::headeronly::ScalarType::BFloat16, "FlashAttention only support fp16 and bf16 data type"); - TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); - TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); - TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype"); - TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype"); + STD_TORCH_CHECK(k.scalar_type() == q_dtype, "query and key must have the same dtype"); + STD_TORCH_CHECK(v.scalar_type() == q_dtype, "query and value must have the same dtype"); + STD_TORCH_CHECK(out.scalar_type() == q_dtype, "query and out must have the same dtype"); + STD_TORCH_CHECK(dout.scalar_type() == q_dtype, "query and dout must have the same dtype"); CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse); - TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension"); - TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension"); + STD_TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + STD_TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + STD_TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + STD_TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension"); + STD_TORCH_CHECK(dout.stride(-1) == 1, 
"dout tensor must have contiguous last dimension"); - const auto sizes = q.sizes(); + std::vector sizes; + sizes.reserve(q.dim()); + for (int64_t __d = 0; __d < q.dim(); ++__d) sizes.push_back(q.size(__d)); const int batch_size = sizes[0]; const int seqlen_q = sizes[1]; @@ -825,17 +862,17 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x multipl const int head_size = sizes[3]; const int seqlen_k = k.size(1); const int num_heads_k = k.size(2); - TORCH_CHECK(batch_size > 0, "batch size must be positive"); - TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8"); - TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256"); - TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + STD_TORCH_CHECK(batch_size > 0, "batch size must be positive"); + STD_TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8"); + STD_TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256"); + STD_TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; const int head_size_rounded = round_multiple(head_size, head_size <= 128 ? 
32 : 64); const int seqlen_q_rounded = round_multiple(seqlen_q, 128); const int seqlen_k_rounded = round_multiple(seqlen_k, 128); - if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); } + if (softcap > 0.f) { STD_TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); } if (window_size_left >= seqlen_k) { window_size_left = -1; } if (window_size_right >= seqlen_k) { window_size_right = -1; } @@ -846,58 +883,57 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x multipl CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size); CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size); - at::Tensor dq, dk, dv; + Tensor dq, dk, dv; if (dq_.has_value()) { dq = dq_.value(); - TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q"); + STD_TORCH_CHECK(dq.scalar_type() == q_dtype, "dq must have the same dtype as q"); CHECK_DEVICE(dq); - TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension"); + STD_TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension"); CHECK_SHAPE(dq, batch_size, seqlen_q, num_heads, head_size); } else { - dq = torch::empty_like(q); + dq = torch::stable::empty_like(q); } if (dk_.has_value()) { dk = dk_.value(); - TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q"); + STD_TORCH_CHECK(dk.scalar_type() == q_dtype, "dk must have the same dtype as q"); CHECK_DEVICE(dk); - TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension"); + STD_TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension"); CHECK_SHAPE(dk, batch_size, seqlen_k, num_heads_k, head_size); } else { - dk = torch::empty_like(k); + dk = torch::stable::empty_like(k); } if (dv_.has_value()) { dv = dv_.value(); - TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q"); + STD_TORCH_CHECK(dv.scalar_type() == q_dtype, "dv must have the same dtype as q"); CHECK_DEVICE(dv); - 
TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension"); + STD_TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension"); CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size); } else { - dv = torch::empty_like(v); + dv = torch::stable::empty_like(v); } // bool loop = seqlen_k > blocksize_c; // TODO: change later, for now set to true for simplicity bool loop = true; - auto opts = q.options(); - auto softmax_d = torch::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat)); - at::Tensor dq_accum; - at::Tensor dk_accum, dv_accum; + auto softmax_d = torch::stable::new_empty(q, {batch_size, num_heads, seqlen_q_rounded}, torch::headeronly::ScalarType::Float); + Tensor dq_accum; + Tensor dk_accum, dv_accum; if (loop) { if (!deterministic) { - dq_accum = torch::empty({batch_size, seqlen_q_rounded, num_heads, head_size_rounded}, opts.dtype(at::kFloat)); + dq_accum = torch::stable::new_empty(q, {batch_size, seqlen_q_rounded, num_heads, head_size_rounded}, torch::headeronly::ScalarType::Float); } else { const int nsplits = (get_num_sm(get_current_device()) + batch_size * num_heads - 1) / (batch_size * num_heads); - dq_accum = torch::zeros({nsplits, batch_size, seqlen_q_rounded, num_heads, head_size_rounded}, opts.dtype(at::kFloat)); + dq_accum = torch::stable::new_zeros(q, {nsplits, batch_size, seqlen_q_rounded, num_heads, head_size_rounded}, torch::headeronly::ScalarType::Float); } - // dk_accum = torch::empty({batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, opts.dtype(at::kFloat)); - // dv_accum = torch::empty({batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, opts.dtype(at::kFloat)); + // dk_accum = torch::stable::new_empty(q, {batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, torch::headeronly::ScalarType::Float); + // dv_accum = torch::stable::new_empty(q, {batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, torch::headeronly::ScalarType::Float); } - 
at::Tensor dk_expanded, dv_expanded; + Tensor dk_expanded, dv_expanded; if (num_heads_k != num_heads) { // MQA / GQA - dk_expanded = torch::empty({batch_size, seqlen_k, num_heads, head_size}, opts); - dv_expanded = torch::empty({batch_size, seqlen_k, num_heads, head_size}, opts); + dk_expanded = torch::stable::new_empty(q, {batch_size, seqlen_k, num_heads, head_size}); + dv_expanded = torch::stable::new_empty(q, {batch_size, seqlen_k, num_heads, head_size}); } else { dk_expanded = dk; dv_expanded = dv; @@ -933,21 +969,14 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x multipl auto launch = &run_mha_bwd; - auto gen = at::get_generator_or_default( - gen_, at::cuda::detail::getDefaultCUDAGenerator()); - // We use a custom RNG that increases the offset by batch_size * nheads * 32. int64_t counter_offset = params.b * params.h * 32; + (void)counter_offset; if ( rng_state.has_value() ) { params.rng_state = reinterpret_cast(rng_state.value().data_ptr()); - } else if( is_dropout ) { - // See Note [Acquire lock when using random generators] - std::lock_guard lock(gen->mutex_); - params.philox_args = gen->philox_cuda_state(counter_offset); - auto seeds = at::cuda::philox::unpack(params.philox_args); - params.rng_state[0] = std::get<0>(seeds); - params.rng_state[1] = std::get<1>(seeds); + } else { + STD_TORCH_CHECK(!is_dropout, "flash-attn2 stable-ABI build does not support dropout backward without a saved rng_state."); } set_params_alibi(params, alibi_slopes_, batch_size, num_heads); @@ -956,33 +985,33 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x multipl launch(params, stream); } else { // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0. 
- dk_expanded.zero_(); - dv_expanded.zero_(); - softmax_d.zero_(); + torch::stable::fill_(dk_expanded, 0); + torch::stable::fill_(dv_expanded, 0); + torch::stable::fill_(softmax_d, 0); } // For MQA/GQA we need to sum dK and dV across the groups if (num_heads_k != num_heads) { - at::sum_out(dk, at::reshape(dk_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3}); - at::sum_out(dv, at::reshape(dv_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3}); + torch::stable::sum_out(dk, torch::stable::reshape(dk_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3}); + torch::stable::sum_out(dv, torch::stable::reshape(dv_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3}); } return { dq, dk, dv, softmax_d }; } -std::vector -mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size - const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i - const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i - const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i - const at::Tensor &out, // total_q x num_heads x head_size - const at::Tensor &softmax_lse, // h x total_q, softmax logsumexp - std::optional &dq_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i - std::optional &dk_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i - std::optional &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i - const at::Tensor &cu_seqlens_q, // b+1 - const at::Tensor &cu_seqlens_k, // b+1 - std::optional &alibi_slopes_, // num_heads or b x num_heads +std::vector +mha_varlen_bwd(const Tensor &dout, // total_q x num_heads, x head_size + const Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const Tensor &k, // total_k x num_heads_k x head_size, total_k := 
\sum_{i=0}^{b} s_i + const Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const Tensor &out, // total_q x num_heads x head_size + const Tensor &softmax_lse, // h x total_q, softmax logsumexp + std::optional &dq_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + std::optional &dk_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + std::optional &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i + const Tensor &cu_seqlens_q, // b+1 + const Tensor &cu_seqlens_k, // b+1 + std::optional &alibi_slopes_, // num_heads or b x num_heads const int max_seqlen_q, const int max_seqlen_k, // max sequence length to choose the kernel const float p_dropout, // probability to drop @@ -993,47 +1022,51 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size int window_size_right, const float softcap, const bool deterministic, - std::optional gen_, - std::optional &rng_state) { + std::optional &rng_state) { #ifdef FLASHATTENTION_DISABLE_BACKWARD - TORCH_CHECK(false, "This flash attention build does not support backward."); + STD_TORCH_CHECK(false, "This flash attention build does not support backward."); #endif if (is_causal) { window_size_right = 0; } // Otherwise the kernel will be launched from cuda:0 device - at::cuda::CUDAGuard device_guard{q.device()}; + auto device_guard = make_device_guard(q); auto [cc_major, cc_minor] = get_compute_capability(get_current_device()); bool is_sm8x_min = cc_major >= 8; - TORCH_CHECK(is_sm8x_min, "FlashAttention only supports Ampere GPUs or newer."); + STD_TORCH_CHECK(is_sm8x_min, "FlashAttention only supports Ampere GPUs or newer."); bool is_dropout = p_dropout > 0.0; - auto stream = at::cuda::getCurrentCUDAStream().stream(); + auto stream_device_idx = torch::stable::accelerator::getCurrentDeviceIndex(); + void* stream_ptr = nullptr; + TORCH_ERROR_CODE_CHECK(aoti_torch_get_current_cuda_stream(stream_device_idx, &stream_ptr)); + 
cudaStream_t stream = static_cast(stream_ptr); - auto q_dtype = q.dtype(); - TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, + auto q_dtype = q.scalar_type(); + STD_TORCH_CHECK(q_dtype == torch::headeronly::ScalarType::Half || q_dtype == torch::headeronly::ScalarType::BFloat16, "FlashAttention only support fp16 and bf16 data type"); - TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); - TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); - TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype"); - TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype"); - TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32"); - TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32"); + STD_TORCH_CHECK(k.scalar_type() == q_dtype, "query and key must have the same dtype"); + STD_TORCH_CHECK(v.scalar_type() == q_dtype, "query and value must have the same dtype"); + STD_TORCH_CHECK(out.scalar_type() == q_dtype, "query and out must have the same dtype"); + STD_TORCH_CHECK(dout.scalar_type() == q_dtype, "query and dout must have the same dtype"); + STD_TORCH_CHECK(cu_seqlens_q.scalar_type() == torch::headeronly::ScalarType::Int, "cu_seqlens_q must have dtype int32"); + STD_TORCH_CHECK(cu_seqlens_k.scalar_type() == torch::headeronly::ScalarType::Int, "cu_seqlens_k must have dtype int32"); CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse); CHECK_DEVICE(cu_seqlens_q); CHECK_DEVICE(cu_seqlens_k); - TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension"); - 
TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension"); + STD_TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + STD_TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + STD_TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + STD_TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension"); + STD_TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension"); CHECK_CONTIGUOUS(cu_seqlens_q); CHECK_CONTIGUOUS(cu_seqlens_k); - const auto sizes = q.sizes(); + std::vector sizes; + sizes.reserve(q.dim()); + for (int64_t __d = 0; __d < q.dim(); ++__d) sizes.push_back(q.size(__d)); const int total_q = sizes[0]; const int batch_size = cu_seqlens_q.numel() - 1; @@ -1041,11 +1074,11 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size const int head_size = sizes[2]; const int total_k = k.size(0); const int num_heads_k = k.size(1); - TORCH_CHECK(batch_size > 0, "batch size must be positive"); - TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8"); - TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256"); - TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); - if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); } + STD_TORCH_CHECK(batch_size > 0, "batch size must be positive"); + STD_TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8"); + STD_TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256"); + STD_TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + if (softcap > 0.f) { STD_TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); } auto round_multiple = [](int x, int 
m) { return (x + m - 1) / m * m; }; const int head_size_rounded = round_multiple(head_size, head_size <= 128 ? 32 : 64); @@ -1063,42 +1096,41 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size CHECK_SHAPE(cu_seqlens_q, batch_size + 1); CHECK_SHAPE(cu_seqlens_k, batch_size + 1); - at::Tensor dq, dk, dv; + Tensor dq, dk, dv; if (dq_.has_value()) { dq = dq_.value(); - TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q"); + STD_TORCH_CHECK(dq.scalar_type() == q_dtype, "dq must have the same dtype as q"); CHECK_DEVICE(dq); - TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension"); + STD_TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension"); CHECK_SHAPE(dq, total_q, num_heads, head_size); } else { - dq = torch::empty_like(q); + dq = torch::stable::empty_like(q); } if (dk_.has_value()) { dk = dk_.value(); - TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q"); + STD_TORCH_CHECK(dk.scalar_type() == q_dtype, "dk must have the same dtype as q"); CHECK_DEVICE(dk); - TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension"); + STD_TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension"); CHECK_SHAPE(dk, total_k, num_heads_k, head_size); } else { - dk = torch::empty_like(k); + dk = torch::stable::empty_like(k); } if (dv_.has_value()) { dv = dv_.value(); - TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q"); + STD_TORCH_CHECK(dv.scalar_type() == q_dtype, "dv must have the same dtype as q"); CHECK_DEVICE(dv); - TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension"); + STD_TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension"); CHECK_SHAPE(dv, total_k, num_heads_k, head_size); } else { - dv = torch::empty_like(v); + dv = torch::stable::empty_like(v); } // bool loop = max_seqlen_k > blocksize_c; // TODO: change later, for now set to true for simplicity bool loop = true; - auto opts = 
q.options(); - auto softmax_d = torch::empty({num_heads, total_q + 128 * batch_size}, opts.dtype(at::kFloat)); - at::Tensor dq_accum; + auto softmax_d = torch::stable::new_empty(q, {num_heads, total_q + 128 * batch_size}, torch::headeronly::ScalarType::Float); + Tensor dq_accum; if (loop) { // We don't want to allocate dq_accum of size (batch, seqlen_q_rounded, num_heads, head_size_rounded) // because that would be too large if there is a very long sequence and the rest of the sequences are short. @@ -1110,27 +1142,27 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size // allowed to do. So we won't have to do any bound checking, and performance should stay the same. // Same holds for softmax_d, since LSE is stored in unpadded format. if (!deterministic) { - dq_accum = torch::empty({total_q + 128 * batch_size, num_heads, head_size_rounded}, opts.dtype(at::kFloat)); + dq_accum = torch::stable::new_empty(q, {total_q + 128 * batch_size, num_heads, head_size_rounded}, torch::headeronly::ScalarType::Float); } else { const int nsplits = (get_num_sm(get_current_device()) + batch_size * num_heads - 1) / (batch_size * num_heads); - dq_accum = torch::zeros({nsplits, total_q + 128 * batch_size, num_heads, head_size_rounded}, opts.dtype(at::kFloat)); + dq_accum = torch::stable::new_zeros(q, {nsplits, total_q + 128 * batch_size, num_heads, head_size_rounded}, torch::headeronly::ScalarType::Float); } } - at::Tensor dk_expanded, dv_expanded; + Tensor dk_expanded, dv_expanded; if (num_heads_k != num_heads) { // MQA / GQA - dk_expanded = torch::empty({total_k, num_heads, head_size}, opts); - dv_expanded = torch::empty({total_k, num_heads, head_size}, opts); + dk_expanded = torch::stable::new_empty(q, {total_k, num_heads, head_size}); + dv_expanded = torch::stable::new_empty(q, {total_k, num_heads, head_size}); } else { dk_expanded = dk; dv_expanded = dv; } if( zero_tensors ) { - dq.zero_(); - dk_expanded.zero_(); - dv_expanded.zero_(); - softmax_d.zero_(); 
+ torch::stable::fill_(dq, 0); + torch::stable::fill_(dk_expanded, 0); + torch::stable::fill_(dv_expanded, 0); + torch::stable::fill_(softmax_d, 0); } Flash_bwd_params params; @@ -1162,21 +1194,14 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size auto launch = &run_mha_bwd; - auto gen = at::get_generator_or_default( - gen_, at::cuda::detail::getDefaultCUDAGenerator()); - // We use a custom RNG that increases the offset by batch_size * nheads * 32. int64_t counter_offset = params.b * params.h * 32; + (void)counter_offset; if ( rng_state.has_value() ) { params.rng_state = reinterpret_cast(rng_state.value().data_ptr()); - } else if( is_dropout ) { - // See Note [Acquire lock when using random generators] - std::lock_guard lock(gen->mutex_); - params.philox_args = gen->philox_cuda_state(counter_offset); - auto seeds = at::cuda::philox::unpack(params.philox_args); - params.rng_state[0] = std::get<0>(seeds); - params.rng_state[1] = std::get<1>(seeds); + } else { + STD_TORCH_CHECK(!is_dropout, "flash-attn2 stable-ABI build does not support dropout backward without a saved rng_state."); } set_params_alibi(params, alibi_slopes_, batch_size, num_heads); @@ -1185,34 +1210,34 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size launch(params, stream); } else { // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0. 
- dk_expanded.zero_(); - dv_expanded.zero_(); - softmax_d.zero_(); + torch::stable::fill_(dk_expanded, 0); + torch::stable::fill_(dv_expanded, 0); + torch::stable::fill_(softmax_d, 0); } // For MQA/GQA we need to sum dK and dV across the groups if (num_heads_k != num_heads) { - at::sum_out(dk, at::reshape(dk_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size}), {2}); - at::sum_out(dv, at::reshape(dv_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size}), {2}); + torch::stable::sum_out(dk, torch::stable::reshape(dk_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size}), {2}); + torch::stable::sum_out(dv, torch::stable::reshape(dv_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size}), {2}); } return { dq, dk, dv, softmax_d }; } -std::vector -mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size - const at::Tensor &kcache, // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. - const at::Tensor &vcache, // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. 
- std::optional &k_, // batch_size x seqlen_knew x num_heads_k x head_size - std::optional &v_, // batch_size x seqlen_knew x num_heads_k x head_size - std::optional &seqlens_k_, // batch_size - std::optional &rotary_cos_, // seqlen_ro x (rotary_dim / 2) - std::optional &rotary_sin_, // seqlen_ro x (rotary_dim / 2) - std::optional &cache_batch_idx_, // indices to index into the KV cache - std::optional &leftpad_k_, // batch_size - std::optional &block_table_, // batch_size x max_num_blocks_per_seq - std::optional &alibi_slopes_, // num_heads or batch_size x num_heads - std::optional &out_, // batch_size x seqlen_q x num_heads x head_size +std::vector +mha_fwd_kvcache(Tensor &q, // batch_size x seqlen_q x num_heads x head_size + const Tensor &kcache, // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. + const Tensor &vcache, // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. 
+ std::optional &k_, // batch_size x seqlen_knew x num_heads_k x head_size + std::optional &v_, // batch_size x seqlen_knew x num_heads_k x head_size + std::optional &seqlens_k_, // batch_size + std::optional &rotary_cos_, // seqlen_ro x (rotary_dim / 2) + std::optional &rotary_sin_, // seqlen_ro x (rotary_dim / 2) + std::optional &cache_batch_idx_, // indices to index into the KV cache + std::optional &leftpad_k_, // batch_size + std::optional &block_table_, // batch_size x max_num_blocks_per_seq + std::optional &alibi_slopes_, // num_heads or batch_size x num_heads + std::optional &out_, // batch_size x seqlen_q x num_heads x head_size const float softmax_scale, bool is_causal, int window_size_left, @@ -1223,35 +1248,37 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he ) { // Otherwise the kernel will be launched from cuda:0 device - at::cuda::CUDAGuard device_guard{q.device()}; + auto device_guard = make_device_guard(q); auto [cc_major, cc_minor] = get_compute_capability(get_current_device()); bool is_sm8x_min = cc_major >= 8; - TORCH_CHECK(is_sm8x_min, "FlashAttention only supports Ampere GPUs or newer."); + STD_TORCH_CHECK(is_sm8x_min, "FlashAttention only supports Ampere GPUs or newer."); - auto q_dtype = q.dtype(); - TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, + auto q_dtype = q.scalar_type(); + STD_TORCH_CHECK(q_dtype == torch::headeronly::ScalarType::Half || q_dtype == torch::headeronly::ScalarType::BFloat16, "FlashAttention only support fp16 and bf16 data type"); - TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype"); - TORCH_CHECK(vcache.dtype() == q_dtype, "query and value must have the same dtype"); + STD_TORCH_CHECK(kcache.scalar_type() == q_dtype, "query and key must have the same dtype"); + STD_TORCH_CHECK(vcache.scalar_type() == q_dtype, "query and value must have the same dtype"); CHECK_DEVICE(q); CHECK_DEVICE(kcache); CHECK_DEVICE(vcache); - TORCH_CHECK(q.stride(-1) == 1, 
"Input tensor must have contiguous last dimension"); - TORCH_CHECK(kcache.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - TORCH_CHECK(vcache.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + STD_TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + STD_TORCH_CHECK(kcache.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + STD_TORCH_CHECK(vcache.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - at::Tensor block_table; + Tensor block_table; const bool paged_KV = block_table_.has_value(); if (paged_KV) { - TORCH_CHECK(!cache_batch_idx_.has_value(), "Paged KVcache does not support cache_batch_idx"); + STD_TORCH_CHECK(!cache_batch_idx_.has_value(), "Paged KVcache does not support cache_batch_idx"); block_table = block_table_.value(); CHECK_DEVICE(block_table); - TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32"); - TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension"); + STD_TORCH_CHECK(block_table.scalar_type() == torch::headeronly::ScalarType::Int, "block_table must have dtype torch.int32"); + STD_TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension"); } - const auto sizes = q.sizes(); + std::vector sizes; + sizes.reserve(q.dim()); + for (int64_t __d = 0; __d < q.dim(); ++__d) sizes.push_back(q.size(__d)); const int batch_size = sizes[0]; int seqlen_q = sizes[1]; @@ -1261,13 +1288,13 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1); const int num_blocks = !paged_KV ? 0 : kcache.size(0); const int page_block_size = !paged_KV ? 
1 : kcache.size(1); - TORCH_CHECK(!paged_KV || page_block_size % 256 == 0, "Paged KV cache block size must be divisible by 256"); + STD_TORCH_CHECK(!paged_KV || page_block_size % 256 == 0, "Paged KV cache block size must be divisible by 256"); const int seqlen_k = !paged_KV ? kcache.size(1) : max_num_blocks_per_seq * page_block_size; const int num_heads_k = kcache.size(2); const int batch_size_c = !paged_KV ? kcache.size(0) : batch_size; - TORCH_CHECK(batch_size > 0, "batch size must be positive"); - TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256"); - TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + STD_TORCH_CHECK(batch_size > 0, "batch size must be positive"); + STD_TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256"); + STD_TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); // causal=true is the same as causal=false in this case if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; } @@ -1278,7 +1305,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && head_size_og % 8 == 0 && !alibi_slopes_.has_value(); if (seqlenq_ngroups_swapped) { const int ngroups = num_heads / num_heads_k; - q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2); + q = torch::stable::transpose(torch::stable::reshape(q, {batch_size, num_heads_k, ngroups, head_size_og}), 1, 2); seqlen_q = ngroups; num_heads = num_heads_k; } @@ -1296,27 +1323,27 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq); } - at::Tensor q_padded, kcache_padded, vcache_padded; + Tensor q_padded, kcache_padded, vcache_padded; if 
(head_size_og % 8 != 0) { - q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); - kcache_padded = torch::nn::functional::pad(kcache, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); - vcache_padded = torch::nn::functional::pad(vcache, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); + q_padded = torch::stable::pad(q, {0, 8 - head_size_og % 8}); + kcache_padded = torch::stable::pad(kcache, {0, 8 - head_size_og % 8}); + vcache_padded = torch::stable::pad(vcache, {0, 8 - head_size_og % 8}); } else { q_padded = q; kcache_padded = kcache; vcache_padded = vcache; } - at::Tensor out; + Tensor out; if (out_.has_value()) { out = out_.value(); - TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs"); + STD_TORCH_CHECK(out.scalar_type() == q_dtype, "Output must have the same dtype as inputs"); CHECK_DEVICE(out); - TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); + STD_TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og); - if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); } + if (head_size_og % 8 != 0) { out = torch::stable::empty_like(q_padded); } } else { - out = torch::empty_like(q_padded); + out = torch::stable::empty_like(q_padded); } auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; @@ -1325,9 +1352,8 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he const int seqlen_q_rounded = round_multiple(seqlen_q, 128); const int seqlen_k_rounded = round_multiple(seqlen_k, 128); - auto opts = q.options(); - auto softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat)); + auto softmax_lse = torch::stable::new_empty(q, {batch_size, num_heads, seqlen_q}, torch::headeronly::ScalarType::Float); Flash_fwd_params params; set_params_fprop(params, @@ -1349,24 
+1375,24 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he softcap ); - at::Tensor k, v, k_padded, v_padded; + Tensor k, v, k_padded, v_padded; if (k_.has_value()) { - TORCH_CHECK(v_.has_value(), "If key is supplied, value must also be passed in"); - TORCH_CHECK(seqlens_k_.has_value(), "If key is supplied, seqlens_k must also be passed in"); - TORCH_CHECK(seqlen_q <= seqlen_k, "If key is supplied, it must have seqlen <= the seqlen of the KV cache"); + STD_TORCH_CHECK(v_.has_value(), "If key is supplied, value must also be passed in"); + STD_TORCH_CHECK(seqlens_k_.has_value(), "If key is supplied, seqlens_k must also be passed in"); + STD_TORCH_CHECK(seqlen_q <= seqlen_k, "If key is supplied, it must have seqlen <= the seqlen of the KV cache"); k = k_.value(); v = v_.value(); - TORCH_CHECK(k.dtype() == q_dtype, "Key must have the same dtype as query"); - TORCH_CHECK(v.dtype() == q_dtype, "Value must have the same dtype as query"); + STD_TORCH_CHECK(k.scalar_type() == q_dtype, "Key must have the same dtype as query"); + STD_TORCH_CHECK(v.scalar_type() == q_dtype, "Value must have the same dtype as query"); CHECK_DEVICE(k); CHECK_DEVICE(v); - TORCH_CHECK(k.stride(-1) == 1, "Key tensor must have contiguous last dimension"); - TORCH_CHECK(v.stride(-1) == 1, "Value tensor must have contiguous last dimension"); + STD_TORCH_CHECK(k.stride(-1) == 1, "Key tensor must have contiguous last dimension"); + STD_TORCH_CHECK(v.stride(-1) == 1, "Value tensor must have contiguous last dimension"); int seqlen_knew = k.size(1); CHECK_SHAPE(k, batch_size, seqlen_knew, num_heads_k, head_size_og); CHECK_SHAPE(v, batch_size, seqlen_knew, num_heads_k, head_size_og); if (head_size_og % 8 != 0) { - k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); - v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); + k_padded = torch::stable::pad(k, {0, 8 - head_size_og % 8}); 
+ v_padded = torch::stable::pad(v, {0, 8 - head_size_og % 8}); } else { k_padded = k; v_padded = v; @@ -1385,7 +1411,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he if (seqlens_k_.has_value()) { auto seqlens_k = seqlens_k_.value(); - TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32"); + STD_TORCH_CHECK(seqlens_k.scalar_type() == torch::headeronly::ScalarType::Int, "seqlens_k must have dtype int32"); CHECK_DEVICE(seqlens_k); CHECK_CONTIGUOUS(seqlens_k); CHECK_SHAPE(seqlens_k, batch_size); @@ -1393,9 +1419,9 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he } params.is_seqlens_k_cumulative = !(seqlens_k_.has_value()); if (leftpad_k_.has_value()) { - TORCH_CHECK(!paged_KV, "We don't support Paged KV and leftpad_k running at the same time yet"); + STD_TORCH_CHECK(!paged_KV, "We don't support Paged KV and leftpad_k running at the same time yet"); auto leftpad_k = leftpad_k_.value(); - TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32"); + STD_TORCH_CHECK(leftpad_k.scalar_type() == torch::headeronly::ScalarType::Int, "leftpad_k must have dtype int32"); CHECK_DEVICE(leftpad_k); CHECK_CONTIGUOUS(leftpad_k); CHECK_SHAPE(leftpad_k, batch_size); @@ -1403,24 +1429,24 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he } if (rotary_cos_.has_value()) { - TORCH_CHECK(k_.has_value(), "If rotary cos/sin are provided, new key / value to be appended to KV cache must also be provided"); + STD_TORCH_CHECK(k_.has_value(), "If rotary cos/sin are provided, new key / value to be appended to KV cache must also be provided"); auto rotary_cos = rotary_cos_.value(); CHECK_DEVICE(rotary_cos); params.rotary_dim = rotary_cos.size(1) * 2; - TORCH_CHECK(params.rotary_dim <= head_size, "rotary_dim must be <= headdim"); - TORCH_CHECK(params.rotary_dim % 16 == 0, "Only rotary dimensions divisible by 16 are currently supported"); + STD_TORCH_CHECK(params.rotary_dim <= head_size, 
"rotary_dim must be <= headdim"); + STD_TORCH_CHECK(params.rotary_dim % 16 == 0, "Only rotary dimensions divisible by 16 are currently supported"); const int seqlen_ro = rotary_cos.size(0); - TORCH_CHECK(seqlen_ro >= seqlen_k, "cos/sin seqlen must be at least the seqlen of KV cache"); + STD_TORCH_CHECK(seqlen_ro >= seqlen_k, "cos/sin seqlen must be at least the seqlen of KV cache"); CHECK_SHAPE(rotary_cos, seqlen_ro, params.rotary_dim / 2); CHECK_CONTIGUOUS(rotary_cos); - TORCH_CHECK(rotary_cos.scalar_type() == q_dtype, "rotary_cos must have the same dtype as query"); + STD_TORCH_CHECK(rotary_cos.scalar_type() == q_dtype, "rotary_cos must have the same dtype as query"); - TORCH_CHECK(rotary_sin_.has_value(), "If rotary cos is provided, rotary sin must also be provided"); + STD_TORCH_CHECK(rotary_sin_.has_value(), "If rotary cos is provided, rotary sin must also be provided"); auto rotary_sin = rotary_sin_.value(); CHECK_DEVICE(rotary_sin); CHECK_SHAPE(rotary_sin, seqlen_ro, params.rotary_dim / 2); CHECK_CONTIGUOUS(rotary_sin); - TORCH_CHECK(rotary_sin.scalar_type() == q_dtype, "rotary_cos must have the same dtype as query"); + STD_TORCH_CHECK(rotary_sin.scalar_type() == q_dtype, "rotary_cos must have the same dtype as query"); params.rotary_cos_ptr = rotary_cos.data_ptr(); params.rotary_sin_ptr = rotary_sin.data_ptr(); params.is_rotary_interleaved = is_rotary_interleaved; @@ -1432,18 +1458,18 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he auto cache_batch_idx = cache_batch_idx_.value(); CHECK_DEVICE(cache_batch_idx); CHECK_CONTIGUOUS(cache_batch_idx); - TORCH_CHECK(cache_batch_idx.scalar_type() == torch::kInt32, "cache_batch_idx must have dtype int32"); + STD_TORCH_CHECK(cache_batch_idx.scalar_type() == torch::headeronly::ScalarType::Int, "cache_batch_idx must have dtype int32"); params.cache_batch_idx = reinterpret_cast(cache_batch_idx.data_ptr()); } // Keep references to these tensors to extend their lifetime - at::Tensor softmax_lse_accum, 
out_accum; + Tensor softmax_lse_accum, out_accum; std::tie(softmax_lse_accum, out_accum) = set_params_splitkv( params, batch_size, num_heads, head_size, seqlen_k, seqlen_q, - head_size_rounded, /*dropout*/ 0.f, num_splits, get_num_sm(get_current_device()), opts); + head_size_rounded, /*dropout*/ 0.f, num_splits, get_num_sm(get_current_device()), q); if (paged_KV) { - params.block_table = block_table.data_ptr(); + params.block_table = static_cast(block_table.data_ptr()); params.block_table_batch_stride = block_table.stride(0); } params.page_block_size = page_block_size; @@ -1451,290 +1477,203 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he set_params_alibi(params, alibi_slopes_, batch_size, num_heads); - auto stream = at::cuda::getCurrentCUDAStream().stream(); + auto stream_device_idx = torch::stable::accelerator::getCurrentDeviceIndex(); + void* stream_ptr = nullptr; + TORCH_ERROR_CODE_CHECK(aoti_torch_get_current_cuda_stream(stream_device_idx, &stream_ptr)); + cudaStream_t stream = static_cast(stream_ptr); // Only split kernel supports appending to KV cache, or indexing to the cache with cache_batch_idx, // or paged KV cache run_mha_fwd(params, stream, /*force_split_kernel=*/k_.has_value() || cache_batch_idx_.has_value() || paged_KV); if (head_size_og % 8 != 0) { - out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}); - if (out_.has_value()) { out_.value().copy_(out); } + out = torch::stable::narrow(out, -1, 0, head_size_og); + if (out_.has_value()) { Tensor o_ = out_.value(); torch::stable::copy_(o_, out); } if (k_.has_value()) { // It's expensive to copy the KV cache here for the case where head size not divisible by 8, // but we don't expect to get this case in practice. This is just so that the code works for that case. 
- kcache.copy_(kcache_padded.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)})); - vcache.copy_(vcache_padded.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)})); + Tensor kcache_mut = kcache; + Tensor vcache_mut = vcache; + torch::stable::copy_(kcache_mut, torch::stable::narrow(kcache_padded, -1, 0, head_size_og)); + torch::stable::copy_(vcache_mut, torch::stable::narrow(vcache_padded, -1, 0, head_size_og)); } } if (seqlenq_ngroups_swapped) { - out = out.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og}); - softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1}); + out = torch::stable::reshape(torch::stable::transpose(out, 1, 2), {batch_size, 1, num_heads_k * seqlen_q, head_size_og}); + softmax_lse = torch::stable::reshape(softmax_lse, {batch_size, num_heads_k * seqlen_q, 1}); } return {out, softmax_lse}; } + } // namespace FLASH_NAMESPACE -// std::tuple -std::vector -mha_fwd( - torch::Tensor &q, - const torch::Tensor &k, - const torch::Tensor &v, - c10::optional out_, - c10::optional alibi_slopes_, - const double p_dropout, - const double softmax_scale, - bool is_causal, - const int64_t window_size_left, - const int64_t window_size_right, - const double softcap, - const bool return_softmax, - c10::optional gen_) { - return FLASH_NAMESPACE::mha_fwd( - q, - k, - v, - out_, - alibi_slopes_, - static_cast(p_dropout), - static_cast(softmax_scale), - is_causal, - static_cast(window_size_left), - static_cast(window_size_right), - static_cast(softcap), - return_softmax, - gen_ - ); +// Boxed entry points for the stable ABI: unbox the stack, dispatch, box the +// result tuple back. The `Generator? gen_` schema slot is accepted but ignored. 
+ +void boxed_mha_fwd(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { + Tensor q = to(stack[0]); + Tensor k = to(stack[1]); + Tensor v = to(stack[2]); + std::optional out_ = to>(stack[3]); + std::optional alibi_slopes_ = to>(stack[4]); + double p_dropout = to(stack[5]); + double softmax_scale = to(stack[6]); + bool is_causal = to(stack[7]); + int64_t window_size_left = to(stack[8]); + int64_t window_size_right = to(stack[9]); + double softcap = to(stack[10]); + bool return_softmax = to(stack[11]); + // stack[12]: Generator? gen_ (ignored under the stable ABI) + + auto out = FLASH_NAMESPACE::mha_fwd( + q, k, v, out_, alibi_slopes_, + static_cast(p_dropout), static_cast(softmax_scale), is_causal, + static_cast(window_size_left), static_cast(window_size_right), + static_cast(softcap), return_softmax); + + stack[0] = from(out[0]); + stack[1] = from(out[1]); + stack[2] = from(out[2]); + stack[3] = from(out[3]); } -std::vector -mha_varlen_fwd( - torch::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i - const torch::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. - const torch::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. - std::optional out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i - const torch::Tensor &cu_seqlens_q, // b+1 - const torch::Tensor &cu_seqlens_k, // b+1 - std::optional seqused_k, // b. If given, only this many elements of each batch element's keys are used. 
- std::optional leftpad_k_, // batch_size - std::optional block_table_, // batch_size x max_num_blocks_per_seq - std::optional alibi_slopes_, // num_heads or b x num_heads - int64_t max_seqlen_q, - const int64_t max_seqlen_k, - const double p_dropout, - const double softmax_scale, - const bool zero_tensors, - bool is_causal, - int64_t window_size_left, - int64_t window_size_right, - const double softcap, - const bool return_softmax, - std::optional gen_) { - return FLASH_NAMESPACE::mha_varlen_fwd( - const_cast(q), - k, - v, - out_, - cu_seqlens_q, - cu_seqlens_k, - seqused_k, - reinterpret_cast&>(leftpad_k_), - block_table_, - alibi_slopes_, - static_cast(max_seqlen_q), - static_cast(max_seqlen_k), - static_cast(p_dropout), - static_cast(softmax_scale), - zero_tensors, - is_causal, - static_cast(window_size_left), - static_cast(window_size_right), - static_cast(softcap), - return_softmax, - gen_ - ); +void boxed_mha_varlen_fwd(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { + Tensor q = to(stack[0]); + Tensor k = to(stack[1]); + Tensor v = to(stack[2]); + std::optional out_ = to>(stack[3]); + Tensor cu_seqlens_q = to(stack[4]); + Tensor cu_seqlens_k = to(stack[5]); + std::optional seqused_k = to>(stack[6]); + std::optional leftpad_k_ = to>(stack[7]); + std::optional block_table_ = to>(stack[8]); + std::optional alibi_slopes_ = to>(stack[9]); + int64_t max_seqlen_q = to(stack[10]); + int64_t max_seqlen_k = to(stack[11]); + double p_dropout = to(stack[12]); + double softmax_scale = to(stack[13]); + bool zero_tensors = to(stack[14]); + bool is_causal = to(stack[15]); + int64_t window_size_left = to(stack[16]); + int64_t window_size_right = to(stack[17]); + double softcap = to(stack[18]); + bool return_softmax = to(stack[19]); + // stack[20]: Generator? 
gen_ (ignored) + + auto out = FLASH_NAMESPACE::mha_varlen_fwd( + q, k, v, out_, cu_seqlens_q, cu_seqlens_k, seqused_k, leftpad_k_, block_table_, + alibi_slopes_, static_cast(max_seqlen_q), static_cast(max_seqlen_k), + static_cast(p_dropout), static_cast(softmax_scale), zero_tensors, is_causal, + static_cast(window_size_left), static_cast(window_size_right), + static_cast(softcap), return_softmax); + + stack[0] = from(out[0]); + stack[1] = from(out[1]); + stack[2] = from(out[2]); + stack[3] = from(out[3]); } -std::vector -mha_bwd(const torch::Tensor &dout, // batch_size x seqlen_q x num_heads, x multiple_of(head_size_og, 8) - const torch::Tensor &q, // batch_size x seqlen_q x num_heads x head_size - const torch::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size - const torch::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size - const torch::Tensor &out, // batch_size x seqlen_q x num_heads x head_size - const torch::Tensor &softmax_lse, // b x h x seqlen_q - const c10::optional &dq_, // batch_size x seqlen_q x num_heads x head_size - const c10::optional &dk_, // batch_size x seqlen_k x num_heads_k x head_size - const c10::optional &dv_, // batch_size x seqlen_k x num_heads_k x head_size - const c10::optional &alibi_slopes_, // num_heads or batch_size x num_heads - const double p_dropout, // probability to drop - const double softmax_scale, - const bool is_causal, - const int64_t window_size_left, - const int64_t window_size_right, - const double softcap, - const bool deterministic, - c10::optional gen_, - const c10::optional &rng_state) { - - auto gen = gen_.value_or(at::cuda::detail::getDefaultCUDAGenerator()); - - // Prepare the optional arguments as non-const references. - std::optional dq = dq_.has_value() ? std::optional(const_cast(dq_.value())) : std::nullopt; - std::optional dk = dk_.has_value() ? std::optional(const_cast(dk_.value())) : std::nullopt; - std::optional dv = dv_.has_value() ? 
std::optional(const_cast(dv_.value())) : std::nullopt; - std::optional alibi_slopes = alibi_slopes_.has_value() ? std::optional(const_cast(alibi_slopes_.value())) : std::nullopt; - - // Convert double to float and int64_t to int. - float p_dropout_float = static_cast(p_dropout); - float softmax_scale_float = static_cast(softmax_scale); - float softcap_float = static_cast(softcap); - int window_size_left_int = static_cast(window_size_left); - int window_size_right_int = static_cast(window_size_right); - - // TODO: avoid copying rng_state if possible - // Create a non-const copy of rng_state - std::optional rng_state_copy; - if (rng_state.has_value()) { - rng_state_copy = rng_state.value().clone(); - } - - return FLASH_NAMESPACE::mha_bwd( - const_cast(dout), - q, k, v, out, softmax_lse, - dq, dk, dv, alibi_slopes, - p_dropout_float, softmax_scale_float, - is_causal, - window_size_left_int, window_size_right_int, - softcap_float, deterministic, - gen, rng_state_copy); +void boxed_mha_bwd(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { + Tensor dout = to(stack[0]); + Tensor q = to(stack[1]); + Tensor k = to(stack[2]); + Tensor v = to(stack[3]); + Tensor out = to(stack[4]); + Tensor softmax_lse = to(stack[5]); + std::optional dq_ = to>(stack[6]); + std::optional dk_ = to>(stack[7]); + std::optional dv_ = to>(stack[8]); + std::optional alibi_slopes_ = to>(stack[9]); + double p_dropout = to(stack[10]); + double softmax_scale = to(stack[11]); + bool is_causal = to(stack[12]); + int64_t window_size_left = to(stack[13]); + int64_t window_size_right = to(stack[14]); + double softcap = to(stack[15]); + bool deterministic = to(stack[16]); + // stack[17]: Generator? 
gen_ (ignored) + std::optional rng_state = to>(stack[18]); + + auto out_v = FLASH_NAMESPACE::mha_bwd( + dout, q, k, v, out, softmax_lse, dq_, dk_, dv_, alibi_slopes_, + static_cast(p_dropout), static_cast(softmax_scale), is_causal, + static_cast(window_size_left), static_cast(window_size_right), + static_cast(softcap), deterministic, rng_state); + + stack[0] = from(out_v[0]); + stack[1] = from(out_v[1]); + stack[2] = from(out_v[2]); + stack[3] = from(out_v[3]); } - -std::vector -mha_varlen_bwd(const torch::Tensor &dout, // batch_size x seqlen_q x num_heads, x multiple_of(head_size_og, 8) - const torch::Tensor &q, // batch_size x seqlen_q x num_heads x head_size - const torch::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size - const torch::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size - const torch::Tensor &out, // batch_size x seqlen_q x num_heads x head_size - const torch::Tensor &softmax_lse, // b x h x seqlen_q - const c10::optional &dq_, // batch_size x seqlen_q x num_heads x head_size - const c10::optional &dk_, // batch_size x seqlen_k x num_heads_k x head_size - const c10::optional &dv_, // batch_size x seqlen_k x num_heads_k x head_size - const torch::Tensor &cu_seqlens_q, // batch_size + 1 - const torch::Tensor &cu_seqlens_k, // batch_size + 1 - const c10::optional &alibi_slopes_, // num_heads or b x num_heads - const int64_t max_seqlen_q, - const int64_t max_seqlen_k, - const double p_dropout, - const double softmax_scale, - const bool zero_tensors, - const bool is_causal, - const int64_t window_size_left, - const int64_t window_size_right, - const double softcap, - const bool deterministic, - c10::optional gen_, - const c10::optional &rng_state) { - - auto gen = gen_.value_or(at::cuda::detail::getDefaultCUDAGenerator()); - - // Prepare the optional arguments as non-const references. - std::optional dq = dq_.has_value() ? std::optional(const_cast(dq_.value())) : std::nullopt; - std::optional dk = dk_.has_value() ? 
std::optional(const_cast(dk_.value())) : std::nullopt; - std::optional dv = dv_.has_value() ? std::optional(const_cast(dv_.value())) : std::nullopt; - std::optional alibi_slopes = alibi_slopes_.has_value() ? std::optional(const_cast(alibi_slopes_.value())) : std::nullopt; - - // Convert double to float and int64_t to int. - float p_dropout_float = static_cast(p_dropout); - float softmax_scale_float = static_cast(softmax_scale); - float softcap_float = static_cast(softcap); - int max_seqlen_q_int = static_cast(max_seqlen_q); - int max_seqlen_k_int = static_cast(max_seqlen_k); - int window_size_left_int = static_cast(window_size_left); - int window_size_right_int = static_cast(window_size_right); - - - // TODO: avoid copying rng_state if possible - // Create a non-const copy of rng_state - std::optional rng_state_copy; - if (rng_state.has_value()) { - rng_state_copy = rng_state.value().clone(); - } - - return FLASH_NAMESPACE::mha_varlen_bwd( - const_cast(dout), - q, k, v, out, softmax_lse, - dq, dk, dv, - cu_seqlens_q, cu_seqlens_k, - alibi_slopes, - max_seqlen_q_int, max_seqlen_k_int, - p_dropout_float, softmax_scale_float, - zero_tensors, is_causal, - window_size_left_int, window_size_right_int, - softcap_float, deterministic, - gen, rng_state_copy); +void boxed_mha_varlen_bwd(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { + Tensor dout = to(stack[0]); + Tensor q = to(stack[1]); + Tensor k = to(stack[2]); + Tensor v = to(stack[3]); + Tensor out = to(stack[4]); + Tensor softmax_lse = to(stack[5]); + std::optional dq_ = to>(stack[6]); + std::optional dk_ = to>(stack[7]); + std::optional dv_ = to>(stack[8]); + Tensor cu_seqlens_q = to(stack[9]); + Tensor cu_seqlens_k = to(stack[10]); + std::optional alibi_slopes_ = to>(stack[11]); + int64_t max_seqlen_q = to(stack[12]); + int64_t max_seqlen_k = to(stack[13]); + double p_dropout = to(stack[14]); + double softmax_scale = to(stack[15]); + bool zero_tensors = to(stack[16]); + bool is_causal = 
to(stack[17]); + int64_t window_size_left = to(stack[18]); + int64_t window_size_right = to(stack[19]); + double softcap = to(stack[20]); + bool deterministic = to(stack[21]); + // stack[22]: Generator? gen_ (ignored) + std::optional rng_state = to>(stack[23]); + + auto out_v = FLASH_NAMESPACE::mha_varlen_bwd( + dout, q, k, v, out, softmax_lse, dq_, dk_, dv_, cu_seqlens_q, cu_seqlens_k, + alibi_slopes_, static_cast(max_seqlen_q), static_cast(max_seqlen_k), + static_cast(p_dropout), static_cast(softmax_scale), zero_tensors, is_causal, + static_cast(window_size_left), static_cast(window_size_right), + static_cast(softcap), deterministic, rng_state); + + stack[0] = from(out_v[0]); + stack[1] = from(out_v[1]); + stack[2] = from(out_v[2]); + stack[3] = from(out_v[3]); } -std::vector -mha_fwd_kvcache(const torch::Tensor &q, // batch_size x seqlen_q x num_heads x head_size - const torch::Tensor &kcache, // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. - const torch::Tensor &vcache, // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. 
- const c10::optional &k_, // batch_size x seqlen_knew x num_heads_k x head_size - const c10::optional &v_, // batch_size x seqlen_knew x num_heads_k x head_size - const c10::optional &seqlens_k_, // batch_size - const c10::optional &rotary_cos_, // seqlen_ro x (rotary_dim / 2) - const c10::optional &rotary_sin_, // seqlen_ro x (rotary_dim / 2) - const c10::optional &cache_batch_idx_, // indices to index into the KV cache - const c10::optional &leftpad_k_, // batch_size - const c10::optional &block_table_, // batch_size x max_num_blocks_per_seq - const c10::optional &alibi_slopes_, // num_heads or batch_size x num_heads - const c10::optional &out_, // batch_size x seqlen_q x num_heads x head_size - const double softmax_scale, - bool is_causal, - const int64_t window_size_left, - const int64_t window_size_right, - const double softcap, - bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2 - const int64_t num_splits) { - - // Prepare the optional arguments as const references where needed - std::optional k = k_.has_value() ? std::optional(k_.value()) : std::nullopt; - std::optional v = v_.has_value() ? std::optional(v_.value()) : std::nullopt; - std::optional seqlens_k = seqlens_k_.has_value() ? std::optional(seqlens_k_.value()) : std::nullopt; - std::optional rotary_cos = rotary_cos_.has_value() ? std::optional(rotary_cos_.value()) : std::nullopt; - std::optional rotary_sin = rotary_sin_.has_value() ? std::optional(rotary_sin_.value()) : std::nullopt; - std::optional cache_batch_idx = cache_batch_idx_.has_value() ? std::optional(cache_batch_idx_.value()) : std::nullopt; - std::optional leftpad_k = leftpad_k_.has_value() ? std::optional(leftpad_k_.value()) : std::nullopt; - - // For non-const tensors - std::optional block_table = block_table_.has_value() ? std::optional(const_cast(block_table_.value())) : std::nullopt; - std::optional alibi_slopes = alibi_slopes_.has_value() ? 
std::optional(const_cast(alibi_slopes_.value())) : std::nullopt; - std::optional out = out_.has_value() ? std::optional(const_cast(out_.value())) : std::nullopt; - - - // Convert double to float and int64_t to int. - float softmax_scale_float = static_cast(softmax_scale); - float softcap_float = static_cast(softcap); - int window_size_left_int = static_cast(window_size_left); - int window_size_right_int = static_cast(window_size_right); - int num_splits_int = static_cast(num_splits); - - return FLASH_NAMESPACE::mha_fwd_kvcache( - const_cast(q), - kcache, vcache, - k, v, - seqlens_k, - rotary_cos, rotary_sin, - cache_batch_idx, - leftpad_k, - block_table, alibi_slopes, - out, - softmax_scale_float, - is_causal, - window_size_left_int, window_size_right_int, - softcap_float, - is_rotary_interleaved, - num_splits_int - ); -} \ No newline at end of file +void boxed_mha_fwd_kvcache(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { + Tensor q = to(stack[0]); + Tensor kcache = to(stack[1]); + Tensor vcache = to(stack[2]); + std::optional k_ = to>(stack[3]); + std::optional v_ = to>(stack[4]); + std::optional seqlens_k_ = to>(stack[5]); + std::optional rotary_cos_ = to>(stack[6]); + std::optional rotary_sin_ = to>(stack[7]); + std::optional cache_batch_idx_ = to>(stack[8]); + std::optional leftpad_k_ = to>(stack[9]); + std::optional block_table_ = to>(stack[10]); + std::optional alibi_slopes_ = to>(stack[11]); + std::optional out_ = to>(stack[12]); + double softmax_scale = to(stack[13]); + bool is_causal = to(stack[14]); + int64_t window_size_left = to(stack[15]); + int64_t window_size_right = to(stack[16]); + double softcap = to(stack[17]); + bool is_rotary_interleaved = to(stack[18]); + int64_t num_splits = to(stack[19]); + + auto out = FLASH_NAMESPACE::mha_fwd_kvcache( + q, kcache, vcache, k_, v_, seqlens_k_, rotary_cos_, rotary_sin_, cache_batch_idx_, + leftpad_k_, block_table_, alibi_slopes_, out_, + static_cast(softmax_scale), is_causal, + 
static_cast(window_size_left), static_cast(window_size_right), + static_cast(softcap), is_rotary_interleaved, static_cast(num_splits)); + + stack[0] = from(out[0]); + stack[1] = from(out[1]); +} diff --git a/flash-attn2/flash_attn/src/cuda_check.h b/flash-attn2/flash_attn/src/cuda_check.h new file mode 100644 index 00000000..9b139612 --- /dev/null +++ b/flash-attn2/flash_attn/src/cuda_check.h @@ -0,0 +1,25 @@ +// Self-contained CUDA error-check macros. +// +// Replaces c10's (C10_CUDA_CHECK / +// C10_CUDA_KERNEL_LAUNCH_CHECK), which hard-errors in stable-ABI builds under +// TORCH_STABLE_ONLY / TORCH_TARGET_VERSION. These are used from .cu translation +// units compiled by nvcc, where the CUDA runtime is available. + +#pragma once + +#include +#include + +#include + +#define C10_CUDA_CHECK(EXPR) \ + do { \ + const cudaError_t __err = (EXPR); \ + if (__err != cudaSuccess) { \ + fprintf(stderr, "CUDA error %s:%d: %s\n", __FILE__, __LINE__, \ + cudaGetErrorString(__err)); \ + abort(); \ + } \ + } while (0) + +#define C10_CUDA_KERNEL_LAUNCH_CHECK() C10_CUDA_CHECK(cudaGetLastError()) diff --git a/flash-attn2/flash_attn/src/flash.h b/flash-attn2/flash_attn/src/flash.h index 8ffbb62d..5f71be11 100644 --- a/flash-attn2/flash_attn/src/flash.h +++ b/flash-attn2/flash_attn/src/flash.h @@ -9,7 +9,7 @@ #include #include -#include // For at::Generator and at::PhiloxCudaState +#include "philox_unpack.cuh" // For FLASH_NAMESPACE::PhiloxCudaState namespace FLASH_NAMESPACE { constexpr int TOTAL_DIM = 0; @@ -119,7 +119,7 @@ struct Flash_fwd_params : public Qkv_params { float softcap; // Random state. - at::PhiloxCudaState philox_args; + PhiloxCudaState philox_args; // Pointer to the RNG seed (idx 0) and offset (idx 1). 
uint64_t * rng_state; diff --git a/flash-attn2/flash_attn/src/flash_bwd_launch_template.h b/flash-attn2/flash_attn/src/flash_bwd_launch_template.h index 72e7a333..ebb1aec7 100644 --- a/flash-attn2/flash_attn/src/flash_bwd_launch_template.h +++ b/flash-attn2/flash_attn/src/flash_bwd_launch_template.h @@ -5,7 +5,7 @@ #pragma once #include "namespace_config.h" -#include // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK +#include "cuda_check.h" // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK #include "static_switch.h" #include "hardware_info.h" diff --git a/flash-attn2/flash_attn/src/flash_fwd_kernel.h b/flash-attn2/flash_attn/src/flash_fwd_kernel.h index d492c87b..12e6b361 100644 --- a/flash-attn2/flash_attn/src/flash_fwd_kernel.h +++ b/flash-attn2/flash_attn/src/flash_fwd_kernel.h @@ -66,7 +66,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi constexpr int kHeadDim = Kernel_traits::kHeadDim; constexpr int kNWarps = Kernel_traits::kNWarps; - auto seed_offset = at::cuda::philox::unpack(params.philox_args); + auto seed_offset = FLASH_NAMESPACE::philox_compat::unpack(params.philox_args); FLASH_NAMESPACE::Dropout dropout(std::get<0>(seed_offset), std::get<1>(seed_offset), params.p_dropout_in_uint8_t, bidb, bidh, tidx, params.h); diff --git a/flash-attn2/flash_attn/src/flash_fwd_launch_template.h b/flash-attn2/flash_attn/src/flash_fwd_launch_template.h index 934e7b91..04508537 100644 --- a/flash-attn2/flash_attn/src/flash_fwd_launch_template.h +++ b/flash-attn2/flash_attn/src/flash_fwd_launch_template.h @@ -4,7 +4,7 @@ #pragma once #include "namespace_config.h" -#include // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK +#include "cuda_check.h" // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK #include "static_switch.h" #include "hardware_info.h" diff --git a/flash-attn2/flash_attn/src/philox_unpack.cuh b/flash-attn2/flash_attn/src/philox_unpack.cuh index 3a54f45c..ab7f3614 100644 ---
a/flash-attn2/flash_attn/src/philox_unpack.cuh +++ b/flash-attn2/flash_attn/src/philox_unpack.cuh @@ -1,4 +1,65 @@ -// This is purely so that it works with torch 2.1. For torch 2.2+ we can include ATen/cuda/PhiloxUtils.cuh +// Self-contained Philox RNG state + unpack helper. +// +// This replaces ATen's `at::PhiloxCudaState` and `at::cuda::philox::unpack` +// (formerly pulled in via <ATen/cuda/CUDAGeneratorImpl.h> and +// <ATen/cuda/detail/UnpackRaw.cuh>). Those ATen headers cannot be included in +// stable-ABI builds, where they hard-error under TORCH_STABLE_ONLY. The layout +// and semantics here mirror ATen's so the kernels are unchanged. #pragma once -#include + +#include +#include + +#include "namespace_config.h" + +namespace FLASH_NAMESPACE { + +struct PhiloxCudaState { + PhiloxCudaState() = default; + + // Non-captured (eager) state: literal seed and offset values. + PhiloxCudaState(uint64_t seed, uint64_t offset) { + seed_.val = seed; + offset_.val = offset; + } + + // Captured (CUDA graph) state: pointers resolved at kernel launch time. + PhiloxCudaState(int64_t *seed, int64_t *offset_extragraph, + uint32_t offset_intragraph) { + seed_.ptr = seed; + offset_.ptr = offset_extragraph; + offset_intragraph_ = offset_intragraph; + captured_ = true; + } + + union Payload { + uint64_t val; + int64_t *ptr; + }; + + Payload seed_{}; + Payload offset_{}; + uint32_t offset_intragraph_ = 0; + bool captured_ = false; +}; + +// Namespace deliberately not named `philox`: FLASH_NAMESPACE already declares a +// `philox(...)` function in philox.cuh, which would collide. +namespace philox_compat { + +__host__ __device__ __forceinline__ std::tuple +unpack(PhiloxCudaState arg) { + if (arg.captured_) { + // offset_intragraph_ counts thread-local subsequence usage within a graph.
+ return std::make_tuple( + static_cast(*arg.seed_.ptr), + static_cast(*arg.offset_.ptr) + arg.offset_intragraph_); + } else { + return std::make_tuple(arg.seed_.val, arg.offset_.val); + } +} + +} // namespace philox_compat + +} // namespace FLASH_NAMESPACE diff --git a/flash-attn2/torch-ext/torch_binding.cpp b/flash-attn2/torch-ext/torch_binding.cpp index b277ea46..76dec209 100644 --- a/flash-attn2/torch-ext/torch_binding.cpp +++ b/flash-attn2/torch-ext/torch_binding.cpp @@ -1,7 +1,12 @@ +// CUDA registers via the Torch stable ABI in torch_binding_stable.cpp / +// flash_attn/flash_api.cpp; the ATen headers below are only available (and +// only needed) for the CPU and XPU backends. +#if !defined(CUDA_KERNEL) #include #include "registration.h" #include "torch_binding.h" +#endif // TODO: Add all of the functions listed // PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { @@ -13,6 +18,10 @@ // m.def("fwd_kvcache", &FLASH_NAMESPACE::mha_fwd_kvcache, "Forward pass, with KV-cache"); // }  +// CUDA registers via the Torch stable ABI in flash_attn/flash_api.cpp; +// this original ATen registration is only for the CPU and XPU backends. +#if !defined(CUDA_KERNEL) + TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def("fwd(" "Tensor! q, " @@ -28,9 +37,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "float softcap, " "bool return_softmax, " "Generator? gen_) -> Tensor[]"); -#if defined(CUDA_KERNEL) - ops.impl("fwd", torch::kCUDA, &mha_fwd); -#elif defined(XPU_KERNEL) +#if defined(XPU_KERNEL) ops.impl("fwd", torch::kXPU, &mha_fwd); #endif @@ -56,9 +63,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "float softcap, " "bool return_softmax, " "Generator? 
gen_) -> Tensor[]"); -#if defined(CUDA_KERNEL) - ops.impl("varlen_fwd", torch::kCUDA, &mha_varlen_fwd); -#elif defined(XPU_KERNEL) +#if defined(XPU_KERNEL) ops.impl("varlen_fwd", torch::kXPU, &mha_varlen_fwd); #elif defined(CPU_KERNEL) ops.impl("varlen_fwd", torch::kCPU, &mha_varlen_fwd); @@ -85,9 +90,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "bool deterministic, " "Generator? gen_, " "Tensor? rng_state) -> Tensor[]"); -#if defined(CUDA_KERNEL) - ops.impl("bwd", torch::kCUDA, &mha_bwd); -#elif defined(XPU_KERNEL) +#if defined(XPU_KERNEL) ops.impl("bwd", torch::kXPU, &mha_bwd); #endif @@ -115,9 +118,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "bool deterministic, " "Generator? gen_, " "Tensor? rng_state) -> Tensor[]"); -#if defined(CUDA_KERNEL) - ops.impl("varlen_bwd", torch::kCUDA, &mha_varlen_bwd); -#elif defined(XPU_KERNEL) +#if defined(XPU_KERNEL) ops.impl("varlen_bwd", torch::kXPU, &mha_varlen_bwd); #endif @@ -142,11 +143,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "float softcap, " "bool is_rotary_interleaved, " "int num_splits) -> Tensor[]"); -#if defined(CUDA_KERNEL) - ops.impl("fwd_kvcache", torch::kCUDA, &mha_fwd_kvcache); -#elif defined(XPU_KERNEL) +#if defined(XPU_KERNEL) ops.impl("fwd_kvcache", torch::kXPU, &mha_fwd_kvcache); #endif } REGISTER_EXTENSION(TORCH_EXTENSION_NAME) + +#endif // !defined(CUDA_KERNEL) diff --git a/flash-attn2/torch-ext/torch_binding.h b/flash-attn2/torch-ext/torch_binding.h index 5dc07397..5efb3478 100644 --- a/flash-attn2/torch-ext/torch_binding.h +++ b/flash-attn2/torch-ext/torch_binding.h @@ -1,5 +1,7 @@ #pragma once +#include + #include // std::tuple diff --git a/flash-attn2/torch-ext/torch_binding_stable.cpp b/flash-attn2/torch-ext/torch_binding_stable.cpp new file mode 100644 index 00000000..28e543fe --- /dev/null +++ b/flash-attn2/torch-ext/torch_binding_stable.cpp @@ -0,0 +1,137 @@ +// CUDA stable-ABI bindings. 
The XPU/CPU (ATen) bindings live in +// torch_binding.cpp; this file is active only for the CUDA backend. +#if defined(CUDA_KERNEL) + +#include + +#include + +#include "registration.h" + +// Boxed entry points, defined in flash_attn/flash_api.cpp. +void boxed_mha_fwd(StableIValue *stack, uint64_t num_args, uint64_t num_outputs); +void boxed_mha_varlen_fwd(StableIValue *stack, uint64_t num_args, uint64_t num_outputs); +void boxed_mha_bwd(StableIValue *stack, uint64_t num_args, uint64_t num_outputs); +void boxed_mha_varlen_bwd(StableIValue *stack, uint64_t num_args, uint64_t num_outputs); +void boxed_mha_fwd_kvcache(StableIValue *stack, uint64_t num_args, uint64_t num_outputs); + +// Schemas return Tensor tuples rather than Tensor[], which the stable ABI cannot box. +STABLE_TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { + ops.def("fwd(" + "Tensor! q, " + "Tensor k, " + "Tensor v, " + "Tensor(out_!)? out_, " + "Tensor? alibi_slopes_, " + "float p_dropout, " + "float softmax_scale, " + "bool is_causal," + "int window_size_left, " + "int window_size_right, " + "float softcap, " + "bool return_softmax, " + "Generator? gen_) -> (Tensor, Tensor, Tensor, Tensor)"); + + ops.def("varlen_fwd(" + "Tensor! q, " + "Tensor k, " + "Tensor v, " + "Tensor? out_, " + "Tensor cu_seqlens_q, " + "Tensor cu_seqlens_k, " + "Tensor? seqused_k_, " + "Tensor? leftpad_k_, " + "Tensor? block_table_, " + "Tensor? alibi_slopes_, " + "int max_seqlen_q, " + "int max_seqlen_k, " + "float p_dropout, " + "float softmax_scale, " + "bool zero_tensors, " + "bool is_causal, " + "int window_size_left, " + "int window_size_right, " + "float softcap, " + "bool return_softmax, " + "Generator? gen_) -> (Tensor, Tensor, Tensor, Tensor)"); + + ops.def("bwd(" + "Tensor! dout, " + "Tensor! q, " + "Tensor! k, " + "Tensor! v, " + "Tensor! out, " + "Tensor! softmax_lse, " + "Tensor? dq_, " + "Tensor? dk_, " + "Tensor? dv_, " + "Tensor? 
alibi_slopes_, " + "float p_dropout, " + "float softmax_scale, " + "bool is_causal, " + "int window_size_left, " + "int window_size_right, " + "float softcap, " + "bool deterministic, " + "Generator? gen_, " + "Tensor? rng_state) -> (Tensor, Tensor, Tensor, Tensor)"); + + ops.def("varlen_bwd(" + "Tensor! dout, " + "Tensor! q, " + "Tensor! k, " + "Tensor! v, " + "Tensor! out, " + "Tensor! softmax_lse, " + "Tensor? dq_, " + "Tensor? dk_, " + "Tensor? dv_, " + "Tensor cu_seqlens_q, " + "Tensor cu_seqlens_k, " + "Tensor? alibi_slopes_, " + "int max_seqlen_q, " + "int max_seqlen_k, " + "float p_dropout, float softmax_scale, " + "bool zero_tensors, " + "bool is_causal, " + "int window_size_left, " + "int window_size_right, " + "float softcap, " + "bool deterministic, " + "Generator? gen_, " + "Tensor? rng_state) -> (Tensor, Tensor, Tensor, Tensor)"); + + ops.def("fwd_kvcache(" + "Tensor! q, " + "Tensor! kcache, " + "Tensor! vcache, " + "Tensor? k_, " + "Tensor? v_, " + "Tensor? seqlens_k_, " + "Tensor? rotary_cos_, " + "Tensor? rotary_sin_, " + "Tensor? cache_batch_idx_, " + "Tensor? leftpad_k_, " + "Tensor? block_table_, " + "Tensor? alibi_slopes_, " + "Tensor? out_, " + "float softmax_scale, " + "bool is_causal, " + "int window_size_left, " + "int window_size_right, " + "float softcap, " + "bool is_rotary_interleaved, " + "int num_splits) -> (Tensor, Tensor)"); +} + +STABLE_TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, ops) { + ops.impl("fwd", &boxed_mha_fwd); + ops.impl("varlen_fwd", &boxed_mha_varlen_fwd); + ops.impl("bwd", &boxed_mha_bwd); + ops.impl("varlen_bwd", &boxed_mha_varlen_bwd); + ops.impl("fwd_kvcache", &boxed_mha_fwd_kvcache); +} + +REGISTER_EXTENSION(TORCH_EXTENSION_NAME) + +#endif // defined(CUDA_KERNEL)