@@ -17,6 +17,9 @@ use text_embeddings_backend_core::{
     Backend, BackendError, Batch, Embedding, Embeddings, ModelType, Predictions,
 };
 
+#[cfg(feature = "cuda")]
+use crate::flash_attn::supports_flash_attn;
+
 #[cfg(feature = "cuda")]
 use crate::compute_cap::{
     compatible_compute_cap, get_compile_compute_cap, get_runtime_compute_cap,
@@ -349,12 +352,7 @@ impl CandleBackend {
             }
             #[cfg(feature = "cuda")]
             (Config::Bert(config), Device::Cuda(_)) => {
-                if cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
-                    && dtype == DType::F16
-                    // Allow disabling because of flash attention v1 precision problems
-                    // See: https://github.com/huggingface/text-embeddings-inference/issues/37
-                    && &std::env::var("USE_FLASH_ATTENTION").unwrap_or("True".to_string()).to_lowercase() == "true"
-                {
+                if supports_flash_attn(&dtype, &device) {
                     match config {
                         BertConfigWrapper::JinaBert(config) => {
                             tracing::info!("Starting FlashJinaBert model on {:?}", device);
@@ -447,18 +445,12 @@ impl CandleBackend {
             }
             #[cfg(feature = "cuda")]
             (Config::Gte(config), Device::Cuda(_)) => {
-                if dtype != DType::F16
-                    || !cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
-                    || &std::env::var("USE_FLASH_ATTENTION")
-                        .unwrap_or("True".to_string())
-                        .to_lowercase()
-                        != "true"
-                {
-                    tracing::info!("Starting GTE model on {:?}", device);
-                    Ok(Box::new(GTEModel::load(vb, &config, model_type).s()?))
-                } else {
+                if supports_flash_attn(&dtype, &device) {
                     tracing::info!("Starting FlashGTE model on {:?}", device);
                     Ok(Box::new(FlashGTEModel::load(vb, &config, model_type).s()?))
+                } else {
+                    tracing::info!("Starting GTE model on {:?}", device);
+                    Ok(Box::new(GTEModel::load(vb, &config, model_type).s()?))
                 }
             }
             #[cfg(feature = "cuda")]
@@ -535,20 +527,14 @@ impl CandleBackend {
             }
             #[cfg(feature = "cuda")]
             (Config::Qwen3(config), Device::Cuda(_)) => {
-                if (dtype != DType::F16 && dtype != DType::BF16)
-                    || !cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
-                    || &std::env::var("USE_FLASH_ATTENTION")
-                        .unwrap_or("True".to_string())
-                        .to_lowercase()
-                        != "true"
-                {
-                    tracing::info!("Starting Qwen3 model on {:?}", device);
-                    Ok(Box::new(Qwen3Model::load(vb, &config, model_type).s()?))
-                } else {
+                if supports_flash_attn(&dtype, &device) {
                     tracing::info!("Starting FlashQwen3 model on {:?}", device);
                     Ok(Box::new(
                         FlashQwen3Model::load(vb, &config, model_type).s()?,
                     ))
+                } else {
+                    tracing::info!("Starting Qwen3 model on {:?}", device);
+                    Ok(Box::new(Qwen3Model::load(vb, &config, model_type).s()?))
                 }
             }
         };
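
For reference, here is a minimal sketch of what the new `supports_flash_attn` helper could look like, reconstructed from the inline conditions this diff removes. The actual implementation in `crate::flash_attn` may differ; in particular, the `Device::is_cuda` check and the assumption that BF16 requires the flash-attn v2 kernels are illustrative, not taken from this diff.

```rust
use candle::{DType, Device};

/// Sketch: whether the flash-attention kernels can be used for this
/// dtype/device combination.
pub fn supports_flash_attn(dtype: &DType, device: &Device) -> bool {
    // Flash attention only runs on CUDA devices.
    if !device.is_cuda() {
        return false;
    }
    // Allow disabling at runtime because of flash attention v1 precision
    // problems. See:
    // https://github.com/huggingface/text-embeddings-inference/issues/37
    let enabled = std::env::var("USE_FLASH_ATTENTION")
        .unwrap_or_else(|_| "true".to_string())
        .to_lowercase()
        == "true";
    if !enabled {
        return false;
    }
    match dtype {
        // F16 is handled by both the flash-attn v1 and v2 kernels.
        DType::F16 => cfg!(any(feature = "flash-attn", feature = "flash-attn-v1")),
        // Assumption: BF16 needs the v2 kernels.
        DType::BF16 => cfg!(feature = "flash-attn"),
        _ => false,
    }
}
```

Centralizing the check keeps the per-model match arms uniform and makes the Qwen3 BF16 case explicit, instead of duplicating the dtype/env-var logic at every call site.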