File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree
Original file line number | Diff line number | Diff line change
@@ -535,7 +535,7 @@ impl CandleBackend {
535535 }
536536 #[ cfg( feature = "cuda" ) ]
537537 ( Config :: Qwen3 ( config) , Device :: Cuda ( _) ) => {
538- if dtype != DType :: F16
538+ if ( dtype != DType :: F16 && dtype != DType :: BF16 )
539539 || !cfg ! ( any( feature = "flash-attn" , feature = "flash-attn-v1" ) )
540540 || & std:: env:: var ( "USE_FLASH_ATTENTION" )
541541 . unwrap_or ( "True" . to_string ( ) )
Original file line number | Diff line number | Diff line change
@@ -300,8 +300,8 @@ impl FlashQwen3Model {
300300 _ => candle:: bail!( "FlashQwen3 requires Cuda" ) ,
301301 }
302302
303- if vb. dtype ( ) != DType :: F16 {
304- candle:: bail!( "FlashQwen3 requires DType::F16" )
303+ if vb. dtype ( ) != DType :: F16 && vb . dtype ( ) != DType :: BF16 {
304+ candle:: bail!( "FlashQwen3 requires DType::F16 or DType::BF16 " )
305305 }
306306
307307 let pool = match model_type {
You can’t perform that action at this time.
0 commit comments