Skip to content

Commit 9755745

Browse files
Add THD guard for A100-class (pre-sm90, Ampere/Ada) GPUs: FusedAttention with THD layout requires cuDNN 9.18.1+
Signed-off-by: Francesco Bertolotti <francesco.bertolotti@igenius.ai>
1 parent 720ec27 commit 9755745

2 files changed

Lines changed: 21 additions & 3 deletions

File tree

transformer_engine/common/fused_attn/fused_attn.cpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -399,16 +399,22 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
399399
// qkv format
400400
(qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BSHD ||
401401
qkv_format == NVTE_QKV_Format::NVTE_BHSD ||
402-
(qkv_format == NVTE_QKV_Format::NVTE_THD && sm_arch_ >= 90 &&
402+
(qkv_format == NVTE_QKV_Format::NVTE_THD &&
403+
(sm_arch_ >= 90 || cudnn_runtime_version >= 91801) &&
403404
((cudnn_runtime_version >= 90100 && num_attn_heads == num_gqa_groups) ||
404405
cudnn_runtime_version >= 90600)) ||
405406
((q_format == NVTE_QKV_Format::NVTE_SBHD || q_format == NVTE_QKV_Format::NVTE_BSHD ||
406407
q_format == NVTE_QKV_Format::NVTE_BHSD ||
407-
(q_format == NVTE_QKV_Format::NVTE_THD && sm_arch_ >= 90) ||
408+
(q_format == NVTE_QKV_Format::NVTE_THD &&
409+
(sm_arch_ >= 90 || cudnn_runtime_version >= 91801)) ||
408410
kv_format == NVTE_QKV_Format::NVTE_SBHD || kv_format == NVTE_QKV_Format::NVTE_BSHD ||
409411
kv_format == NVTE_QKV_Format::NVTE_BHSD ||
410-
(kv_format == NVTE_QKV_Format::NVTE_THD && sm_arch_ >= 90)) &&
412+
(kv_format == NVTE_QKV_Format::NVTE_THD &&
413+
(sm_arch_ >= 90 || cudnn_runtime_version >= 91801))) &&
411414
cudnn_runtime_version >= 90700)) &&
415+
// THD (ragged offset) support: Ampere/Ada (sm80/sm89) only from cuDNN 9.18.1
416+
((q_format != NVTE_QKV_Format::NVTE_THD && kv_format != NVTE_QKV_Format::NVTE_THD) ||
417+
sm_arch_ >= 90 || cudnn_runtime_version >= 91801) &&
412418
// sliding window
413419
// pre-9.2: full attn, causal
414420
((cudnn_runtime_version < 90200 && window_size_left == -1 &&

transformer_engine/pytorch/attention/dot_product_attention/utils.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -996,6 +996,18 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt
996996
qkv_layout,
997997
)
998998
use_fused_attention = False
999+
# THD support on Ampere/Ada requires cuDNN 9.18.1+ ("SDPA backward with THD layout on
1000+
# RTX-PRO 6000 and Ampere-architecture GPUs"). Check q_format/kv_format, not just
1001+
# qkv_format, since KV-cache layouts (e.g. paged_kv_thd_bshd_bshd) have
1002+
# qkv_format = thd_2bshd.
1003+
if "thd" in (q_format, kv_format) and device_compute_capability < (9, 0):
1004+
if cudnn_version < (9, 18, 1):
1005+
if use_fused_attention:
1006+
logger.debug(
1007+
"Disabling FusedAttention as qkv_format = thd is not supported for"
1008+
" compute capability < sm90 and cuDNN version < 9.18.1"
1009+
)
1010+
use_fused_attention = False
9991011

10001012
# Filter: Dropout
10011013
if attention_dropout != 0.0:

0 commit comments

Comments
 (0)