Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion flash-attn2/build.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[general]
name = "flash-attn2"
version = 1
version = 3
license = "BSD-3-Clause"
backends = [
"cpu",
Expand All @@ -12,8 +12,10 @@ backends = [
repo-id = "kernels-community/flash-attn2"

[torch]
stable-abi = "2.10"
src = [
"torch-ext/torch_binding.cpp",
"torch-ext/torch_binding_stable.cpp",
"torch-ext/torch_binding.h",
]

Expand Down Expand Up @@ -180,6 +182,7 @@ depends = [
src = [
"flash_attn/flash_api.cpp",
"flash_attn/src/philox_unpack.cuh",
"flash_attn/src/cuda_check.h",
"flash_attn/src/namespace_config.h",
"flash_attn/src/hardware_info.h",
"flash_attn/src/flash.h",
Expand Down
1,223 changes: 581 additions & 642 deletions flash-attn2/flash_attn/flash_api.cpp

Large diffs are not rendered by default.

25 changes: 25 additions & 0 deletions flash-attn2/flash_attn/src/cuda_check.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// Self-contained CUDA error-check macros.
//
// Replaces c10's <c10/cuda/CUDAException.h> (C10_CUDA_CHECK /
// C10_CUDA_KERNEL_LAUNCH_CHECK), which hard-errors in stable-ABI builds under
// TORCH_STABLE_ONLY / TORCH_TARGET_VERSION. These are used from .cu translation
// units compiled by nvcc, where the CUDA runtime is available.

#pragma once

#include <cstdio>
#include <cstdlib>

#include <cuda_runtime.h>

// C10_CUDA_CHECK(EXPR)
//
// Evaluates EXPR (an expression yielding a cudaError_t) exactly once. If the
// result is anything other than cudaSuccess, prints the source location, the
// symbolic CUDA error name and its description to stderr, then aborts the
// process. Wrapped in do { ... } while (0) so it behaves as a single statement
// after an unbraced `if`/`else`.
//
// The macro-local variable deliberately avoids a double-underscore name:
// identifiers containing `__` are reserved for the implementation. The C
// library calls are std::-qualified because <cstdio>/<cstdlib> only guarantee
// the std:: declarations.
#define C10_CUDA_CHECK(EXPR)                                                   \
  do {                                                                         \
    const cudaError_t flash_cuda_check_status_ = (EXPR);                       \
    if (flash_cuda_check_status_ != cudaSuccess) {                             \
      std::fprintf(stderr, "CUDA error %s:%d: %s: %s\n", __FILE__, __LINE__,   \
                   cudaGetErrorName(flash_cuda_check_status_),                 \
                   cudaGetErrorString(flash_cuda_check_status_));              \
      std::abort();                                                            \
    }                                                                          \
  } while (0)

// C10_CUDA_KERNEL_LAUNCH_CHECK()
//
// Intended to be placed right after a kernel launch: checks the value returned
// by cudaGetLastError() (which also resets the recorded error) through
// C10_CUDA_CHECK. It does not synchronize.
#define C10_CUDA_KERNEL_LAUNCH_CHECK() C10_CUDA_CHECK(cudaGetLastError())
4 changes: 2 additions & 2 deletions flash-attn2/flash_attn/src/flash.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
#include <cuda.h>
#include <vector>

#include <ATen/cuda/CUDAGeneratorImpl.h> // For at::Generator and at::PhiloxCudaState
#include "philox_unpack.cuh" // For FLASH_NAMESPACE::PhiloxCudaState

namespace FLASH_NAMESPACE {
constexpr int TOTAL_DIM = 0;
Expand Down Expand Up @@ -119,7 +119,7 @@ struct Flash_fwd_params : public Qkv_params {
float softcap;

// Random state.
at::PhiloxCudaState philox_args;
PhiloxCudaState philox_args;

// Pointer to the RNG seed (idx 0) and offset (idx 1).
uint64_t * rng_state;
Expand Down
2 changes: 1 addition & 1 deletion flash-attn2/flash_attn/src/flash_bwd_launch_template.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#pragma once

#include "namespace_config.h"
#include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
#include "cuda_check.h" // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK

#include "static_switch.h"
#include "hardware_info.h"
Expand Down
2 changes: 1 addition & 1 deletion flash-attn2/flash_attn/src/flash_fwd_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
constexpr int kHeadDim = Kernel_traits::kHeadDim;
constexpr int kNWarps = Kernel_traits::kNWarps;

auto seed_offset = at::cuda::philox::unpack(params.philox_args);
auto seed_offset = FLASH_NAMESPACE::philox_compat::unpack(params.philox_args);
FLASH_NAMESPACE::Dropout dropout(std::get<0>(seed_offset), std::get<1>(seed_offset), params.p_dropout_in_uint8_t,
bidb, bidh, tidx, params.h);

Expand Down
2 changes: 1 addition & 1 deletion flash-attn2/flash_attn/src/flash_fwd_launch_template.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

#pragma once
#include "namespace_config.h"
#include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
#include "cuda_check.h" // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK

#include "static_switch.h"
#include "hardware_info.h"
Expand Down
65 changes: 63 additions & 2 deletions flash-attn2/flash_attn/src/philox_unpack.cuh
Original file line number Diff line number Diff line change
@@ -1,4 +1,65 @@
// This is purely so that it works with torch 2.1. For torch 2.2+ we can include ATen/cuda/PhiloxUtils.cuh
// Self-contained Philox RNG state + unpack helper.
//
// This replaces ATen's `at::PhiloxCudaState` and `at::cuda::philox::unpack`
// (formerly pulled in via <ATen/cuda/CUDAGeneratorImpl.h> and
// <ATen/cuda/detail/UnpackRaw.cuh>). Those ATen headers cannot be included in
// stable-ABI builds, where they hard-error under TORCH_STABLE_ONLY. The layout
// and semantics here mirror ATen's so the kernels are unchanged.

#pragma once
#include <ATen/cuda/detail/UnpackRaw.cuh>

#include <cstdint>
#include <tuple>

#include "namespace_config.h"

namespace FLASH_NAMESPACE {

// Philox RNG inputs (seed and offset) in one of two forms, selected by
// `captured_`:
//   * captured_ == false: `seed_.val` / `offset_.val` hold literal values.
//   * captured_ == true:  `seed_.ptr` / `offset_.ptr` hold pointers that are
//     dereferenced when the state is unpacked, and `offset_intragraph_` is
//     added to the offset read through the pointer.
// The member order and types are left untouched; the surrounding file states
// that the layout mirrors ATen's.
struct PhiloxCudaState {
  // Storage shared by the two forms: either the value itself or a pointer to it.
  union Payload {
    uint64_t val;
    int64_t *ptr;
  };

  Payload seed_{};
  Payload offset_{};
  uint32_t offset_intragraph_ = 0;
  bool captured_ = false;

  // Zeroed, non-captured state.
  PhiloxCudaState() = default;

  // Non-captured (eager) state: literal seed and offset values.
  PhiloxCudaState(uint64_t seed, uint64_t offset)
      : seed_{seed}, offset_{offset} {}

  // Captured (CUDA graph) state: the seed and the extra-graph offset are given
  // as pointers and are only read when the state is unpacked.
  PhiloxCudaState(int64_t *seed, int64_t *offset_extragraph,
                  uint32_t offset_intragraph)
      : offset_intragraph_(offset_intragraph), captured_(true) {
    seed_.ptr = seed;
    offset_.ptr = offset_extragraph;
  }
};

// Namespace deliberately not named `philox`: FLASH_NAMESPACE already declares a
// `philox(...)` function in philox.cuh, which would collide.
namespace philox_compat {

// Resolves a PhiloxCudaState into a concrete (seed, offset) pair.
//
// Non-captured state: the stored literal values are returned unchanged.
// Captured state: the seed and offset are read through the stored pointers and
// `offset_intragraph_` is added to the offset. The casts to uint64_t happen
// before the addition, so the sum is computed in unsigned arithmetic.
// NOTE(review): in the captured case the pointers must be dereferenceable from
// wherever this is called — presumably device memory when called from a
// kernel; confirm against the caller that builds the state.
__host__ __device__ __forceinline__ std::tuple<uint64_t, uint64_t>
unpack(PhiloxCudaState arg) {
  if (!arg.captured_) {
    return std::make_tuple(arg.seed_.val, arg.offset_.val);
  }

  const uint64_t resolved_seed = static_cast<uint64_t>(*arg.seed_.ptr);
  const uint64_t resolved_offset =
      static_cast<uint64_t>(*arg.offset_.ptr) + arg.offset_intragraph_;
  return std::make_tuple(resolved_seed, resolved_offset);
}

} // namespace philox_compat

} // namespace FLASH_NAMESPACE
31 changes: 16 additions & 15 deletions flash-attn2/torch-ext/torch_binding.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
// The CUDA backend registers its ops through the Torch stable ABI in
// torch_binding_stable.cpp (the boxed kernels are defined in
// flash_attn/flash_api.cpp). The ATen headers below are only available, and
// only needed, for the CPU and XPU backends.
#if !defined(CUDA_KERNEL)
#include <torch/library.h>

#include "registration.h"
#include "torch_binding.h"
#endif

// TODO: Add all of the functions listed
// PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
Expand All @@ -13,6 +18,10 @@
// m.def("fwd_kvcache", &FLASH_NAMESPACE::mha_fwd_kvcache, "Forward pass, with KV-cache");
// } 

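// The CUDA backend registers its ops through the Torch stable ABI in
// torch_binding_stable.cpp (the boxed kernels are defined in
// flash_attn/flash_api.cpp); the original ATen registration below is only
// used for the CPU and XPU backends.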
#if !defined(CUDA_KERNEL)

TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("fwd("
"Tensor! q, "
Expand All @@ -28,9 +37,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"float softcap, "
"bool return_softmax, "
"Generator? gen_) -> Tensor[]");
#if defined(CUDA_KERNEL)
ops.impl("fwd", torch::kCUDA, &mha_fwd);
#elif defined(XPU_KERNEL)
#if defined(XPU_KERNEL)
ops.impl("fwd", torch::kXPU, &mha_fwd);
#endif

Expand All @@ -56,9 +63,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"float softcap, "
"bool return_softmax, "
"Generator? gen_) -> Tensor[]");
#if defined(CUDA_KERNEL)
ops.impl("varlen_fwd", torch::kCUDA, &mha_varlen_fwd);
#elif defined(XPU_KERNEL)
#if defined(XPU_KERNEL)
ops.impl("varlen_fwd", torch::kXPU, &mha_varlen_fwd);
#elif defined(CPU_KERNEL)
ops.impl("varlen_fwd", torch::kCPU, &mha_varlen_fwd);
Expand All @@ -85,9 +90,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"bool deterministic, "
"Generator? gen_, "
"Tensor? rng_state) -> Tensor[]");
#if defined(CUDA_KERNEL)
ops.impl("bwd", torch::kCUDA, &mha_bwd);
#elif defined(XPU_KERNEL)
#if defined(XPU_KERNEL)
ops.impl("bwd", torch::kXPU, &mha_bwd);
#endif

Expand Down Expand Up @@ -115,9 +118,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"bool deterministic, "
"Generator? gen_, "
"Tensor? rng_state) -> Tensor[]");
#if defined(CUDA_KERNEL)
ops.impl("varlen_bwd", torch::kCUDA, &mha_varlen_bwd);
#elif defined(XPU_KERNEL)
#if defined(XPU_KERNEL)
ops.impl("varlen_bwd", torch::kXPU, &mha_varlen_bwd);
#endif

Expand All @@ -142,11 +143,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"float softcap, "
"bool is_rotary_interleaved, "
"int num_splits) -> Tensor[]");
#if defined(CUDA_KERNEL)
ops.impl("fwd_kvcache", torch::kCUDA, &mha_fwd_kvcache);
#elif defined(XPU_KERNEL)
#if defined(XPU_KERNEL)
ops.impl("fwd_kvcache", torch::kXPU, &mha_fwd_kvcache);
#endif
}

REGISTER_EXTENSION(TORCH_EXTENSION_NAME)

#endif // !defined(CUDA_KERNEL)
2 changes: 2 additions & 0 deletions flash-attn2/torch-ext/torch_binding.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#pragma once

#include <vector>

#include <torch/torch.h>

// std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
Expand Down
137 changes: 137 additions & 0 deletions flash-attn2/torch-ext/torch_binding_stable.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
// CUDA stable-ABI bindings. The XPU/CPU (ATen) bindings live in
// torch_binding.cpp; this file is active only for the CUDA backend.
#if defined(CUDA_KERNEL)

#include <cstdint>

#include <torch/csrc/stable/library.h>

#include "registration.h"

// Boxed entry points, defined in flash_attn/flash_api.cpp.
//
// One declaration per op registered below. All five share the same signature:
// a pointer to a StableIValue stack plus the argument and output counts.
// NOTE(review): presumably each function reads its arguments from `stack` and
// writes its results back into it, per the stable-ABI boxed calling
// convention — confirm against flash_api.cpp.
void boxed_mha_fwd(StableIValue *stack, uint64_t num_args, uint64_t num_outputs);
void boxed_mha_varlen_fwd(StableIValue *stack, uint64_t num_args, uint64_t num_outputs);
void boxed_mha_bwd(StableIValue *stack, uint64_t num_args, uint64_t num_outputs);
void boxed_mha_varlen_bwd(StableIValue *stack, uint64_t num_args, uint64_t num_outputs);
void boxed_mha_fwd_kvcache(StableIValue *stack, uint64_t num_args, uint64_t num_outputs);

// Declares the operator schemas for the CUDA backend. Each ops.def() call
// passes one schema string, assembled from adjacent string literals.
//
// Schemas return Tensor tuples rather than Tensor[], which the stable ABI cannot box.
// fwd, varlen_fwd, bwd and varlen_bwd are declared with four Tensor outputs;
// fwd_kvcache is declared with two.
STABLE_TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// fwd: 13 arguments, 4 Tensor outputs.
// NOTE(review): `out_` is annotated `Tensor(out_!)?` here but plain `Tensor?`
// in varlen_fwd below — confirm the difference is intentional.
// NOTE(review): "bool is_causal," lacks the trailing space the other entries
// have; presumably harmless to the schema parser — confirm.
ops.def("fwd("
"Tensor! q, "
"Tensor k, "
"Tensor v, "
"Tensor(out_!)? out_, "
"Tensor? alibi_slopes_, "
"float p_dropout, "
"float softmax_scale, "
"bool is_causal,"
"int window_size_left, "
"int window_size_right, "
"float softcap, "
"bool return_softmax, "
"Generator? gen_) -> (Tensor, Tensor, Tensor, Tensor)");

// varlen_fwd: 21 arguments, 4 Tensor outputs. Adds cu_seqlens_q/cu_seqlens_k,
// the optional seqused_k_/leftpad_k_/block_table_ tensors, max_seqlen_q/
// max_seqlen_k and zero_tensors relative to fwd.
ops.def("varlen_fwd("
"Tensor! q, "
"Tensor k, "
"Tensor v, "
"Tensor? out_, "
"Tensor cu_seqlens_q, "
"Tensor cu_seqlens_k, "
"Tensor? seqused_k_, "
"Tensor? leftpad_k_, "
"Tensor? block_table_, "
"Tensor? alibi_slopes_, "
"int max_seqlen_q, "
"int max_seqlen_k, "
"float p_dropout, "
"float softmax_scale, "
"bool zero_tensors, "
"bool is_causal, "
"int window_size_left, "
"int window_size_right, "
"float softcap, "
"bool return_softmax, "
"Generator? gen_) -> (Tensor, Tensor, Tensor, Tensor)");

// bwd: 19 arguments, 4 Tensor outputs. dq_/dk_/dv_ and rng_state are optional.
ops.def("bwd("
"Tensor! dout, "
"Tensor! q, "
"Tensor! k, "
"Tensor! v, "
"Tensor! out, "
"Tensor! softmax_lse, "
"Tensor? dq_, "
"Tensor? dk_, "
"Tensor? dv_, "
"Tensor? alibi_slopes_, "
"float p_dropout, "
"float softmax_scale, "
"bool is_causal, "
"int window_size_left, "
"int window_size_right, "
"float softcap, "
"bool deterministic, "
"Generator? gen_, "
"Tensor? rng_state) -> (Tensor, Tensor, Tensor, Tensor)");

// varlen_bwd: 24 arguments, 4 Tensor outputs. Adds cu_seqlens_q/cu_seqlens_k,
// max_seqlen_q/max_seqlen_k and zero_tensors relative to bwd.
ops.def("varlen_bwd("
"Tensor! dout, "
"Tensor! q, "
"Tensor! k, "
"Tensor! v, "
"Tensor! out, "
"Tensor! softmax_lse, "
"Tensor? dq_, "
"Tensor? dk_, "
"Tensor? dv_, "
"Tensor cu_seqlens_q, "
"Tensor cu_seqlens_k, "
"Tensor? alibi_slopes_, "
"int max_seqlen_q, "
"int max_seqlen_k, "
"float p_dropout, float softmax_scale, "
"bool zero_tensors, "
"bool is_causal, "
"int window_size_left, "
"int window_size_right, "
"float softcap, "
"bool deterministic, "
"Generator? gen_, "
"Tensor? rng_state) -> (Tensor, Tensor, Tensor, Tensor)");

// fwd_kvcache: 20 arguments, 2 Tensor outputs. Takes no dropout or Generator
// argument.
ops.def("fwd_kvcache("
"Tensor! q, "
"Tensor! kcache, "
"Tensor! vcache, "
"Tensor? k_, "
"Tensor? v_, "
"Tensor? seqlens_k_, "
"Tensor? rotary_cos_, "
"Tensor? rotary_sin_, "
"Tensor? cache_batch_idx_, "
"Tensor? leftpad_k_, "
"Tensor? block_table_, "
"Tensor? alibi_slopes_, "
"Tensor? out_, "
"float softmax_scale, "
"bool is_causal, "
"int window_size_left, "
"int window_size_right, "
"float softcap, "
"bool is_rotary_interleaved, "
"int num_splits) -> (Tensor, Tensor)");
}

// Binds each op name to its boxed implementation for the CUDA dispatch key.
// The names passed to ops.impl() must match the names used in the ops.def()
// schema strings in this file.
STABLE_TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, ops) {
ops.impl("fwd", &boxed_mha_fwd);
ops.impl("varlen_fwd", &boxed_mha_varlen_fwd);
ops.impl("bwd", &boxed_mha_bwd);
ops.impl("varlen_bwd", &boxed_mha_varlen_bwd);
ops.impl("fwd_kvcache", &boxed_mha_fwd_kvcache);
}

REGISTER_EXTENSION(TORCH_EXTENSION_NAME)

#endif // defined(CUDA_KERNEL)