
Commit 6b2da2e

Add supports_flash_attn to remove duplicated code

Parent: 29fe799
2 files changed: 23 additions and 27 deletions

backends/candle/src/flash_attn.rs (11 additions, 1 deletion)

```diff
@@ -1,8 +1,9 @@
-use candle::Tensor;
+use candle::{DType, Device, Tensor};
 use std::sync::Once;
 
 static INIT: Once = Once::new();
 static mut RUNTIME_COMPUTE_CAP: usize = 0;
+
 fn init_runtime_compute_cap() {
     unsafe {
         INIT.call_once(|| {
@@ -19,6 +20,15 @@ pub fn get_runtime_compute_cap() -> usize {
     }
 }
 
+pub fn supports_flash_attn(dtype: &DType, device: &Device) -> bool {
+    (dtype == &DType::F16 || dtype == &DType::BF16)
+        && cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
+        && &std::env::var("USE_FLASH_ATTENTION")
+            .unwrap_or("True".to_string())
+            .to_lowercase()
+            == "true"
+}
+
 #[allow(clippy::too_many_arguments, unused)]
 pub(crate) fn flash_attn_varlen(
     q: &Tensor,
```
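The new predicate folds together the three conditions that were previously duplicated across lib.rs: a half-precision dtype (F16 or BF16), a flash-attn build feature, and the `USE_FLASH_ATTENTION` runtime escape hatch, which defaults to enabled and exists because of the flash-attention-v1 precision problems tracked in huggingface/text-embeddings-inference#37. Note that the Bert and GTE arms previously required F16 exactly, so under BF16 they now also take the flash path. Below is a minimal sketch of a unit test exercising the helper; it is not part of the commit, and it assumes a build with one of the flash-attn features plus a visible CUDA device:

```rust
#[cfg(test)]
mod tests {
    use super::supports_flash_attn;
    use candle::{DType, Device};

    // Illustrative only: requires a CUDA GPU at test time, assumes one of
    // the `flash-attn` / `flash-attn-v1` features was enabled at build
    // time, and mutates process-wide environment state.
    #[test]
    fn flash_attn_gating() -> candle::Result<()> {
        let device = Device::new_cuda(0)?;

        // Half precision qualifies while USE_FLASH_ATTENTION is unset
        // (the helper defaults the variable to "True"):
        std::env::remove_var("USE_FLASH_ATTENTION");
        assert!(supports_flash_attn(&DType::F16, &device));
        assert!(supports_flash_attn(&DType::BF16, &device));

        // Full precision never qualifies, regardless of features or env:
        assert!(!supports_flash_attn(&DType::F32, &device));

        // Any value other than case-insensitive "true" disables the
        // flash path at runtime:
        std::env::set_var("USE_FLASH_ATTENTION", "false");
        assert!(!supports_flash_attn(&DType::F16, &device));
        Ok(())
    }
}
```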

backends/candle/src/lib.rs (12 additions, 26 deletions)

```diff
@@ -17,6 +17,9 @@ use text_embeddings_backend_core::{
     Backend, BackendError, Batch, Embedding, Embeddings, ModelType, Predictions,
 };
 
+#[cfg(feature = "cuda")]
+use crate::flash_attn::supports_flash_attn;
+
 #[cfg(feature = "cuda")]
 use crate::compute_cap::{
     compatible_compute_cap, get_compile_compute_cap, get_runtime_compute_cap,
@@ -349,12 +352,7 @@ impl CandleBackend {
             }
             #[cfg(feature = "cuda")]
             (Config::Bert(config), Device::Cuda(_)) => {
-                if cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
-                    && dtype == DType::F16
-                    // Allow disabling because of flash attention v1 precision problems
-                    // See: https://github.com/huggingface/text-embeddings-inference/issues/37
-                    && &std::env::var("USE_FLASH_ATTENTION").unwrap_or("True".to_string()).to_lowercase() == "true"
-                {
+                if supports_flash_attn(&dtype, &device) {
                     match config {
                         BertConfigWrapper::JinaBert(config) => {
                             tracing::info!("Starting FlashJinaBert model on {:?}", device);
@@ -447,18 +445,12 @@ impl CandleBackend {
             }
             #[cfg(feature = "cuda")]
            (Config::Gte(config), Device::Cuda(_)) => {
-                if dtype != DType::F16
-                    || !cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
-                    || &std::env::var("USE_FLASH_ATTENTION")
-                        .unwrap_or("True".to_string())
-                        .to_lowercase()
-                        != "true"
-                {
-                    tracing::info!("Starting GTE model on {:?}", device);
-                    Ok(Box::new(GTEModel::load(vb, &config, model_type).s()?))
-                } else {
+                if supports_flash_attn(&dtype, &device) {
                     tracing::info!("Starting FlashGTE model on {:?}", device);
                     Ok(Box::new(FlashGTEModel::load(vb, &config, model_type).s()?))
+                } else {
+                    tracing::info!("Starting GTE model on {:?}", device);
+                    Ok(Box::new(GTEModel::load(vb, &config, model_type).s()?))
                 }
             }
             #[cfg(feature = "cuda")]
@@ -535,20 +527,14 @@ impl CandleBackend {
             }
             #[cfg(feature = "cuda")]
             (Config::Qwen3(config), Device::Cuda(_)) => {
-                if (dtype != DType::F16 && dtype != DType::BF16)
-                    || !cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
-                    || &std::env::var("USE_FLASH_ATTENTION")
-                        .unwrap_or("True".to_string())
-                        .to_lowercase()
-                        != "true"
-                {
-                    tracing::info!("Starting Qwen3 model on {:?}", device);
-                    Ok(Box::new(Qwen3Model::load(vb, &config, model_type).s()?))
-                } else {
+                if supports_flash_attn(&dtype, &device) {
                     tracing::info!("Starting FlashQwen3 model on {:?}", device);
                     Ok(Box::new(
                         FlashQwen3Model::load(vb, &config, model_type).s()?,
                     ))
+                } else {
+                    tracing::info!("Starting Qwen3 model on {:?}", device);
+                    Ok(Box::new(Qwen3Model::load(vb, &config, model_type).s()?))
                 }
             }
         };
```
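With the duplicated checks gone, every CUDA arm in `CandleBackend::new` reduces to the same two-branch shape, which is what a future flash-capable model would copy. A hedged sketch of such an arm follows; `Config::Foo`, `FooModel`, and `FlashFooModel` are hypothetical placeholders, not types in the repository:

```rust
// Hypothetical match arm following the shape this commit standardizes.
// Config::Foo, FooModel, and FlashFooModel do not exist in the repo.
#[cfg(feature = "cuda")]
(Config::Foo(config), Device::Cuda(_)) => {
    if supports_flash_attn(&dtype, &device) {
        tracing::info!("Starting FlashFoo model on {:?}", device);
        Ok(Box::new(FlashFooModel::load(vb, &config, model_type).s()?))
    } else {
        tracing::info!("Starting Foo model on {:?}", device);
        Ok(Box::new(FooModel::load(vb, &config, model_type).s()?))
    }
}
```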
