Skip to content

Commit 28608a7

Browse files
requested changes
Signed-off-by: Francesco Bertolotti <francesco.bertolotti@igenius.ai>
1 parent 8d055a4 commit 28608a7

2 files changed

Lines changed: 14 additions & 14 deletions

File tree

transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu

Lines changed: 13 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -86,7 +86,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
8686
const auto cudnn_runtime_version = cudnnGetVersion();
8787
const int device_id = cuda::current_device();
8888
const int sm_arch_ = cuda::sm_arch(device_id);
89-
bool use_ragged_stats = is_ragged_q && cudnn_runtime_version >= 90600 && sm_arch_ != 120;
89+
bool use_ragged_stats = is_ragged_q && sm_arch_ >= 90 && cudnn_runtime_version >= 90600 && sm_arch_ != 120;
9090

9191
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
9292
bool is_paged_kv = (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD);
@@ -98,10 +98,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
9898
int64_t actual_b = b;
9999
if ((is_ragged_q || is_ragged_kv) && cudnn_runtime_version >= 90600) {
100100
NVTE_CHECK(is_padding, "Ragged QKV input requires padding or padding_causal mask!");
101-
// On SM 120, cuDNN support check treats layouts with stride[0] > dim[1]*dim[2]*dim[3]
102-
// as interleaved and rejects them. Use BHSD-like dimensions/strides with max_seqlen at plan build
103-
// so the check passes; ragged offset still provides variable-length boundaries.
104-
if (sm_arch_ != 120) {
101+
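// On architectures below SM 90 and on SM 120, cuDNN requires BHSD-like strides with max_seqlen at plan build.
// NOTE(review): the check below excludes only sm_arch_ == 120; confirm whether other SM12X variants need the same path.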
102+
if (sm_arch_ >= 90 && sm_arch_ != 120) {
105103
// replace batch size and maximum sequence lengths with maximum token counts
106104
// for query and key/value so the graph is static within each quantization bucket
107105
b = max_b;
@@ -385,7 +383,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
385383
}
386384

387385
Stats->set_output(true).set_data_type(fe::DataType_t::FLOAT).set_dim({b, h, s_q, 1});
388-
if (is_ragged_q && cudnn_runtime_version >= 90600) {
386+
if (use_ragged_stats) {
389387
Stats->set_stride({h * s_q, 1, h, 1}).set_ragged_offset(offset_stats);
390388
} else {
391389
Stats->set_stride({h * s_q, s_q, 1, 1});
@@ -590,7 +588,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
590588
const auto cudnn_runtime_version = cudnnGetVersion();
591589
const int device_id = cuda::current_device();
592590
const int sm_arch_ = cuda::sm_arch(device_id);
593-
bool use_ragged_stats = is_ragged_q && cudnn_runtime_version >= 90600 && sm_arch_ != 120;
591+
bool use_ragged_stats = is_ragged_q && sm_arch_ >= 90 && cudnn_runtime_version >= 90600 && sm_arch_ != 120;
594592

595593
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
596594
bool is_paged_kv = (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD);
@@ -602,8 +600,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
602600
int64_t actual_b = b;
603601
if ((is_ragged_q || is_ragged_kv) && cudnn_runtime_version >= 90600) {
604602
NVTE_CHECK(is_padding, "Ragged QKV input requires padding or padding_causal mask!");
605-
// On SM 120, cuDNN support check requires BHSD-like strides with max_seqlen (see fwd).
606-
if (sm_arch_ != 120) {
603+
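// On architectures below SM 90 and on SM 120, cuDNN requires BHSD-like strides with max_seqlen at plan build.
// NOTE(review): the check below excludes only sm_arch_ == 120; confirm whether other SM12X variants need the same path.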
604+
if (sm_arch_ >= 90 && sm_arch_ != 120) {
607605
// replace batch size and maximum sequence lengths with maximum token counts
608606
// for query and key/value so the graph is static within each quantization bucket
609607
b = max_b;
@@ -805,7 +803,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
805803
if (use_ragged_stats) {
806804
sdpa_backward_options.set_max_total_seq_len_q(s_q);
807805
}
808-
if (is_ragged_kv && cudnn_runtime_version >= 90600 && sm_arch_ != 120) {
806+
if (is_ragged_kv && sm_arch_ >= 90 && cudnn_runtime_version >= 90600 && sm_arch_ != 120) {
809807
sdpa_backward_options.set_max_total_seq_len_kv(s_kv);
810808
}
811809

@@ -1139,10 +1137,13 @@ void fused_attn_arbitrary_seqlen_fwd(
11391137
size_t i = 0;
11401138
if (Aux_CTX_Tensors->size == 0) {
11411139
const auto cudnn_runtime_version = cudnnGetVersion();
1140+
bool is_ragged_q = (q_format == NVTE_QKV_Format::NVTE_THD);
1141+
bool use_ragged_stats =
1142+
is_ragged_q && sm_arch_ >= 90 && cudnn_runtime_version >= 90600 && sm_arch_ != 120;
11421143

11431144
Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
11441145
output_S->data.dptr = nullptr;
1145-
if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
1146+
if (use_ragged_stats) {
11461147
output_S->data.shape = {num_tokens_q, num_attn_heads, 1};
11471148
} else {
11481149
output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
@@ -1152,8 +1153,7 @@ void fused_attn_arbitrary_seqlen_fwd(
11521153
if (return_max_logit) {
11531154
Tensor *output_Max = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
11541155
output_Max->data.dptr = nullptr;
1155-
if ((q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) &&
1156-
(sm_arch_ != 120)) {
1156+
if (use_ragged_stats) {
11571157
output_Max->data.shape = {num_tokens_q, num_attn_heads, 1};
11581158
} else {
11591159
output_Max->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};

transformer_engine/pytorch/attention/dot_product_attention/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -965,7 +965,7 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt
965965
use_flash_attention_4 = False
966966

967967
# Filter: QKV layout
968-
if qkv_format == "thd":
968+
if "thd" in (q_format, kv_format):
969969
if pad_between_seqs:
970970
if ( # pylint: disable=too-many-boolean-expressions
971971
use_flash_attention_2 and FlashAttentionUtils.is_installed

0 commit comments

Comments
 (0)