Skip to content

Commit ea1b6b0

Browse files
committed
fix: Flash Attention compatibility check for SM_1xx (RTX 5000 series)
Fixed a build error. Includes the original commit from PR OpenNMT#1873.
1 parent 41eccad commit ea1b6b0

1 file changed

Lines changed: 2 additions & 1 deletion

File tree

src/models/model.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -844,14 +844,15 @@ namespace ctranslate2 {
844844
" running independently a model in each device");
845845
}
846846

847+
bool supports_flash_attention = false;
847848
if (device == Device::CUDA) {
848849
int device_id = ctranslate2::get_device_index(ctranslate2::Device::CUDA);
849850
auto dprops = ctranslate2::cuda::get_device_properties(device_id);
850851
float compute_capability = dprops.major + (dprops.minor / 10.0f);
851852

852853
// Minimum compute capability for Flash Attention is Ampere (8.0)
853854
const float min_flash_attn_compute_capability = 8.0f;
854-
bool supports_flash_attention = compute_capability >= min_flash_attn_compute_capability;
855+
supports_flash_attention = compute_capability >= min_flash_attn_compute_capability;
855856
}
856857

857858
if (use_flash_attention && (device != Device::CUDA || !supports_flash_attention)) {

0 commit comments

Comments
 (0)