diff --git a/Cargo.lock b/Cargo.lock index 58bffaba1..28b81c512 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -550,6 +550,19 @@ dependencies = [ "half", ] +[[package]] +name = "candle-flash-attn-3" +version = "0.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7d952728fd456ab48f94ed46e0ef761e5b68bdd7eba5d45cb0638eb37f7d700c" +dependencies = [ + "anyhow", + "candle-core", + "half", + "num_cpus", + "rayon", +] + [[package]] name = "candle-flash-attn-v1" version = "0.0.1" @@ -4513,6 +4526,7 @@ dependencies = [ "candle-core", "candle-cublaslt", "candle-flash-attn", + "candle-flash-attn-3", "candle-flash-attn-v1", "candle-index-select-cu", "candle-layer-norm", diff --git a/Cargo.toml b/Cargo.toml index 77a89b584..f4a9914e7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -48,6 +48,7 @@ candle = { version = "0.8", package = "candle-core" } candle-nn = { version = "0.8" } candle-transformers = { version = "0.8" } candle-flash-attn = { version = "0.8" } +candle-flash-attn-3 = { version = "0.0.1", features = ["cuda-11"], default-features = false } candle-cublaslt = { version = "0.0.1" } candle-layer-norm = { version = "0.0.1" } candle-index-select-cu = { version = "0.0.1", features = ["cuda-11"], default-features = false } diff --git a/backends/Cargo.toml b/backends/Cargo.toml index bb9d74191..40b94b4cc 100644 --- a/backends/Cargo.toml +++ b/backends/Cargo.toml @@ -29,3 +29,4 @@ mkl = ["text-embeddings-backend-candle?/mkl"] accelerate = ["text-embeddings-backend-candle?/accelerate"] flash-attn = ["text-embeddings-backend-candle?/flash-attn"] flash-attn-v1 = ["text-embeddings-backend-candle?/flash-attn-v1"] +flash-attn-v3 = ["text-embeddings-backend-candle?/flash-attn-v3"] diff --git a/backends/candle/Cargo.toml b/backends/candle/Cargo.toml index 1dbbf2ca8..4b1e33c1f 100644 --- a/backends/candle/Cargo.toml +++ b/backends/candle/Cargo.toml @@ -14,6 +14,7 @@ candle-nn = { workspace = true } candle-transformers = { workspace = true } 
candle-flash-attn = { workspace = true, optional = true} candle-flash-attn-v1 = { workspace = true, optional = true } +candle-flash-attn-3 = { workspace = true, optional = true } candle-cublaslt = { workspace = true, optional = true } candle-index-select-cu = { workspace = true, optional = true, features = ["cuda-11"], default-features = false} candle-layer-norm = { workspace = true, optional = true } @@ -45,3 +46,4 @@ mkl = ["dep:intel-mkl-src", "candle/_mkl"] cuda = ["candle/_cuda", "candle-nn/_cuda", "dep:candle-cublaslt", "dep:candle-layer-norm", "dep:candle-rotary", "dep:candle-index-select-cu"] flash-attn-v1 = ["dep:candle-flash-attn-v1", "cuda"] flash-attn = ["dep:candle-flash-attn", "cuda"] +flash-attn-v3 = ["dep:candle-flash-attn-3", "cuda"] diff --git a/backends/candle/src/flash_attn.rs b/backends/candle/src/flash_attn.rs index 8dbe58cf0..ac840dadb 100644 --- a/backends/candle/src/flash_attn.rs +++ b/backends/candle/src/flash_attn.rs @@ -19,6 +19,14 @@ pub fn get_runtime_compute_cap() -> usize { } } +static PRINT_ONCE: std::sync::Once = std::sync::Once::new(); + +fn print_version_once(version: &str) { + PRINT_ONCE.call_once(|| { + println!("Using michaelfeil/candle-flash-attn-3 v{}", version); + }); +} + #[allow(clippy::too_many_arguments, unused)] pub(crate) fn flash_attn_varlen( q: &Tensor, @@ -62,7 +70,53 @@ pub(crate) fn flash_attn_varlen( #[cfg(not(feature = "flash-attn-v1"))] candle::bail!("Flash attention v1 is not installed. 
Use `flash-attn-v1` feature.") } else if (80..90).contains(&runtime_compute_cap) || runtime_compute_cap == 90 { - #[cfg(feature = "flash-attn")] + #[cfg(feature = "flash-attn-v3")] // NOTE(review): FA3 kernels target Hopper (sm90); this branch is also taken for sm80-89 — confirm the crate supports Ampere or gate on cap == 90 + { + use candle_flash_attn_v3::{flash_attn_varlen_alibi_windowed, flash_attn_varlen_windowed}; // NOTE(review): dependency is declared as package `candle-flash-attn-3` with no Cargo rename; this path resolves only if the upstream crate sets `[lib] name = "candle_flash_attn_v3"` — otherwise use `candle_flash_attn_3` or add `package =` rename. Verify it builds. + print_version_once("0.0.1"); + let window_size_right = if causal { + Some(0) + } else if window_size_right.is_some() { + window_size_right + } else { + None + }; + + let attention = if let Some(alibi_slopes) = alibi_slopes { + flash_attn_varlen_alibi_windowed( + q, + k, + v, + alibi_slopes, + seqlens_q, + seqlens_k, + max_seqlen_q, + max_seqlen_k, + softmax_scale, + window_size_left, + window_size_right, + false, // use_gqa_packing - set to false for now + ) + } else { + flash_attn_varlen_windowed( + q, + k, + v, + seqlens_q, + seqlens_k, + max_seqlen_q, + max_seqlen_k, + softmax_scale, + window_size_left, + window_size_right, + false, // use_gqa_packing - set to false for now + ) + }; + + return attention; + } + + #[cfg(all(not(feature = "flash-attn-v3"), feature = "flash-attn"))] { use candle_flash_attn::{flash_attn_varlen_alibi_windowed, flash_attn_varlen_windowed}; @@ -105,8 +159,8 @@ pub(crate) fn flash_attn_varlen( return attention; } - #[cfg(not(feature = "flash-attn"))] - candle::bail!("Flash attention is not installed. Use `flash-attn` feature.") + #[cfg(not(any(feature = "flash-attn-v3", feature = "flash-attn")))] + candle::bail!("Flash attention is not installed. Use `flash-attn-v3` or `flash-attn` feature.") } candle::bail!( "GPU with CUDA capability {} is not supported",