Skip to content

Commit 9387804

Browse files
committed
Add BF16 support for FlashQwen3
1 parent 9745b0e commit 9387804

2 files changed

Lines changed: 3 additions & 3 deletions

File tree

backends/candle/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -535,7 +535,7 @@ impl CandleBackend {
535 535
}
536 536
#[cfg(feature = "cuda")]
537 537
(Config::Qwen3(config), Device::Cuda(_)) => {
538-
if dtype != DType::F16
538+
if (dtype != DType::F16 && dtype != DType::BF16)
539 539
|| !cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
540 540
|| &std::env::var("USE_FLASH_ATTENTION")
541 541
.unwrap_or("True".to_string())

backends/candle/src/models/flash_qwen3.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -300,8 +300,8 @@ impl FlashQwen3Model {
300 300
_ => candle::bail!("FlashQwen3 requires Cuda"),
301 301
}
302 302

303-
if vb.dtype() != DType::F16 {
304-
candle::bail!("FlashQwen3 requires DType::F16")
303+
if vb.dtype() != DType::F16 && vb.dtype() != DType::BF16 {
304+
candle::bail!("FlashQwen3 requires DType::F16 or DType::BF16")
305 305
}
306 306

307 307
let pool = match model_type {

0 commit comments

Comments (0)