Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions transformer_engine/common/fused_attn/fused_attn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -399,16 +399,22 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
// qkv format
(qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BSHD ||
qkv_format == NVTE_QKV_Format::NVTE_BHSD ||
(qkv_format == NVTE_QKV_Format::NVTE_THD && sm_arch_ >= 90 &&
(qkv_format == NVTE_QKV_Format::NVTE_THD &&
(sm_arch_ >= 90 || cudnn_runtime_version >= 91801) &&
((cudnn_runtime_version >= 90100 && num_attn_heads == num_gqa_groups) ||
cudnn_runtime_version >= 90600)) ||
((q_format == NVTE_QKV_Format::NVTE_SBHD || q_format == NVTE_QKV_Format::NVTE_BSHD ||
q_format == NVTE_QKV_Format::NVTE_BHSD ||
(q_format == NVTE_QKV_Format::NVTE_THD && sm_arch_ >= 90) ||
(q_format == NVTE_QKV_Format::NVTE_THD &&
(sm_arch_ >= 90 || cudnn_runtime_version >= 91801)) ||
kv_format == NVTE_QKV_Format::NVTE_SBHD || kv_format == NVTE_QKV_Format::NVTE_BSHD ||
kv_format == NVTE_QKV_Format::NVTE_BHSD ||
(kv_format == NVTE_QKV_Format::NVTE_THD && sm_arch_ >= 90)) &&
(kv_format == NVTE_QKV_Format::NVTE_THD &&
(sm_arch_ >= 90 || cudnn_runtime_version >= 91801))) &&
cudnn_runtime_version >= 90700)) &&
// THD (ragged offset) support: Ampere/Ada (sm80/sm89) only from cuDNN 9.18.1
((q_format != NVTE_QKV_Format::NVTE_THD && kv_format != NVTE_QKV_Format::NVTE_THD) ||
sm_arch_ >= 90 || cudnn_runtime_version >= 91801) &&
// sliding window
// pre-9.2: full attn, causal
((cudnn_runtime_version < 90200 && window_size_left == -1 &&
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
const auto cudnn_runtime_version = cudnnGetVersion();
const int device_id = cuda::current_device();
const int sm_arch_ = cuda::sm_arch(device_id);
bool use_ragged_stats = is_ragged_q && cudnn_runtime_version >= 90600 && sm_arch_ != 120;
bool use_ragged_stats =
is_ragged_q && sm_arch_ >= 90 && cudnn_runtime_version >= 90600 && sm_arch_ != 120;

NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
bool is_paged_kv = (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD);
Expand All @@ -98,10 +99,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
int64_t actual_b = b;
if ((is_ragged_q || is_ragged_kv) && cudnn_runtime_version >= 90600) {
NVTE_CHECK(is_padding, "Ragged QKV input requires padding or padding_causal mask!");
// On SM 120, cuDNN support check treats layouts with stride[0] > dim[1]*dim[2]*dim[3]
// as interleaved and rejects them. Use BHSD-like dimensions/strides with max_seqlen at plan build
// so the check passes; ragged offset still provides variable-length boundaries.
if (sm_arch_ != 120) {
// On SM8X/SM12X, cuDNN requires BHSD-like strides with max_seqlen at plan build.
if (sm_arch_ >= 90 && sm_arch_ != 120) {
// replace batch size and maximum sequence lengths with maximum token counts
// for query and key/value so the graph is static within each quantization bucket
b = max_b;
Expand Down Expand Up @@ -385,7 +384,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
}

Stats->set_output(true).set_data_type(fe::DataType_t::FLOAT).set_dim({b, h, s_q, 1});
if (is_ragged_q && cudnn_runtime_version >= 90600) {
if (use_ragged_stats) {
Stats->set_stride({h * s_q, 1, h, 1}).set_ragged_offset(offset_stats);
} else {
Stats->set_stride({h * s_q, s_q, 1, 1});
Expand Down Expand Up @@ -590,7 +589,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
const auto cudnn_runtime_version = cudnnGetVersion();
const int device_id = cuda::current_device();
const int sm_arch_ = cuda::sm_arch(device_id);
bool use_ragged_stats = is_ragged_q && cudnn_runtime_version >= 90600 && sm_arch_ != 120;
bool use_ragged_stats =
is_ragged_q && sm_arch_ >= 90 && cudnn_runtime_version >= 90600 && sm_arch_ != 120;

NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
bool is_paged_kv = (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD);
Expand All @@ -602,8 +602,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
int64_t actual_b = b;
if ((is_ragged_q || is_ragged_kv) && cudnn_runtime_version >= 90600) {
NVTE_CHECK(is_padding, "Ragged QKV input requires padding or padding_causal mask!");
// On SM 120, cuDNN support check requires BHSD-like strides with max_seqlen (see fwd).
if (sm_arch_ != 120) {
// On SM8X/SM12X, cuDNN requires BHSD-like strides with max_seqlen at plan build.
if (sm_arch_ >= 90 && sm_arch_ != 120) {
// replace batch size and maximum sequence lengths with maximum token counts
// for query and key/value so the graph is static within each quantization bucket
b = max_b;
Expand Down Expand Up @@ -805,7 +805,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
if (use_ragged_stats) {
sdpa_backward_options.set_max_total_seq_len_q(s_q);
}
if (is_ragged_kv && cudnn_runtime_version >= 90600 && sm_arch_ != 120) {
if (is_ragged_kv && sm_arch_ >= 90 && cudnn_runtime_version >= 90600 && sm_arch_ != 120) {
sdpa_backward_options.set_max_total_seq_len_kv(s_kv);
}

Expand Down Expand Up @@ -1139,10 +1139,13 @@ void fused_attn_arbitrary_seqlen_fwd(
size_t i = 0;
if (Aux_CTX_Tensors->size == 0) {
const auto cudnn_runtime_version = cudnnGetVersion();
bool is_ragged_q = (q_format == NVTE_QKV_Format::NVTE_THD);
bool use_ragged_stats =
is_ragged_q && sm_arch_ >= 90 && cudnn_runtime_version >= 90600 && sm_arch_ != 120;

Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
output_S->data.dptr = nullptr;
if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
if (use_ragged_stats) {
output_S->data.shape = {num_tokens_q, num_attn_heads, 1};
} else {
output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
Expand All @@ -1152,8 +1155,7 @@ void fused_attn_arbitrary_seqlen_fwd(
if (return_max_logit) {
Tensor *output_Max = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
output_Max->data.dptr = nullptr;
if ((q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) &&
(sm_arch_ != 120)) {
if (use_ragged_stats) {
output_Max->data.shape = {num_tokens_q, num_attn_heads, 1};
} else {
output_Max->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -970,7 +970,7 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt
use_flash_attention_4 = False

# Filter: QKV layout
if qkv_format == "thd":
if "thd" in (q_format, kv_format):
if pad_between_seqs:
if ( # pylint: disable=too-many-boolean-expressions
use_flash_attention_2 and FlashAttentionUtils.is_installed
Expand Down Expand Up @@ -1001,6 +1001,18 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt
qkv_layout,
)
use_fused_attention = False
# THD support on Ampere/Ada requires cuDNN 9.18.1+ ("SDPA backward with THD layout on
# RTX-PRO 6000 and Ampere-architecture GPUs"). Check q_format/kv_format, not just
# qkv_format, since KV-cache layouts (e.g. paged_kv_thd_bshd_bshd) have
# qkv_format = thd_2bshd.
if "thd" in (q_format, kv_format) and device_compute_capability < (9, 0):
if cudnn_version < (9, 18, 1):
if use_fused_attention:
logger.debug(
"Disabling FusedAttention as q_format or kv_format = thd is not supported for"
" compute capability < sm90 and cuDNN version < 9.18.1"
)
Comment thread
cyanguwa marked this conversation as resolved.
use_fused_attention = False

# Filter: Dropout
if attention_dropout != 0.0:
Expand Down