@@ -86,7 +86,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
8686 const auto cudnn_runtime_version = cudnnGetVersion ();
8787 const int device_id = cuda::current_device ();
8888 const int sm_arch_ = cuda::sm_arch (device_id);
89- bool use_ragged_stats = is_ragged_q && cudnn_runtime_version >= 90600 && sm_arch_ != 120 ;
89+ bool use_ragged_stats = is_ragged_q && sm_arch_ >= 90 && cudnn_runtime_version >= 90600 && sm_arch_ != 120 ;
9090
9191 NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group (qkv_layout);
9292 bool is_paged_kv = (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD);
@@ -98,10 +98,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
9898 int64_t actual_b = b;
9999 if ((is_ragged_q || is_ragged_kv) && cudnn_runtime_version >= 90600 ) {
100100 NVTE_CHECK (is_padding, " Ragged QKV input requires padding or padding_causal mask!" );
101- // On SM 120, cuDNN support check treats layouts with stride[0] > dim[1]*dim[2]*dim[3]
102- // as interleaved and rejects them. Use BHSD-like dimensions/strides with max_seqlen at plan build
103- // so the check passes; ragged offset still provides variable-length boundaries.
104- if (sm_arch_ != 120 ) {
101+ // On SM8X/SM12X, cuDNN requires BHSD-like strides with max_seqlen at plan build, so the token-count rewrite below only runs when sm_arch_ >= 90 and sm_arch_ != 120.
102+ if (sm_arch_ >= 90 && sm_arch_ != 120 ) {
105103 // replace batch size and maximum sequence lengths with maximum token counts
106104 // for query and key/value so the graph is static within each quantization bucket
107105 b = max_b;
@@ -385,7 +383,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
385383 }
386384
387385 Stats->set_output (true ).set_data_type (fe::DataType_t::FLOAT ).set_dim ({b, h, s_q, 1 });
388- if (is_ragged_q && cudnn_runtime_version >= 90600 ) {
386+ if (use_ragged_stats ) {
389387 Stats->set_stride ({h * s_q, 1 , h, 1 }).set_ragged_offset (offset_stats);
390388 } else {
391389 Stats->set_stride ({h * s_q, s_q, 1 , 1 });
@@ -590,7 +588,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
590588 const auto cudnn_runtime_version = cudnnGetVersion ();
591589 const int device_id = cuda::current_device ();
592590 const int sm_arch_ = cuda::sm_arch (device_id);
593- bool use_ragged_stats = is_ragged_q && cudnn_runtime_version >= 90600 && sm_arch_ != 120 ;
591+ bool use_ragged_stats = is_ragged_q && sm_arch_ >= 90 && cudnn_runtime_version >= 90600 && sm_arch_ != 120 ;
594592
595593 NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group (qkv_layout);
596594 bool is_paged_kv = (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD);
@@ -602,8 +600,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
602600 int64_t actual_b = b;
603601 if ((is_ragged_q || is_ragged_kv) && cudnn_runtime_version >= 90600 ) {
604602 NVTE_CHECK (is_padding, " Ragged QKV input requires padding or padding_causal mask!" );
605- // On SM 120 , cuDNN support check requires BHSD-like strides with max_seqlen (see fwd) .
606- if (sm_arch_ != 120 ) {
603+ // On SM8X/SM12X, cuDNN requires BHSD-like strides with max_seqlen at plan build, so the token-count rewrite below only runs when sm_arch_ >= 90 and sm_arch_ != 120.
604+ if (sm_arch_ >= 90 && sm_arch_ != 120 ) {
607605 // replace batch size and maximum sequence lengths with maximum token counts
608606 // for query and key/value so the graph is static within each quantization bucket
609607 b = max_b;
@@ -805,7 +803,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
805803 if (use_ragged_stats) {
806804 sdpa_backward_options.set_max_total_seq_len_q (s_q);
807805 }
808- if (is_ragged_kv && cudnn_runtime_version >= 90600 && sm_arch_ != 120 ) {
806+ if (is_ragged_kv && sm_arch_ >= 90 && cudnn_runtime_version >= 90600 && sm_arch_ != 120 ) {
809807 sdpa_backward_options.set_max_total_seq_len_kv (s_kv);
810808 }
811809
@@ -1139,10 +1137,13 @@ void fused_attn_arbitrary_seqlen_fwd(
11391137 size_t i = 0 ;
11401138 if (Aux_CTX_Tensors->size == 0 ) {
11411139 const auto cudnn_runtime_version = cudnnGetVersion ();
1140+ bool is_ragged_q = (q_format == NVTE_QKV_Format::NVTE_THD );
1141+ bool use_ragged_stats =
1142+ is_ragged_q && sm_arch_ >= 90 && cudnn_runtime_version >= 90600 && sm_arch_ != 120 ;
11421143
11431144 Tensor *output_S = convertNVTETensorCheck (Aux_CTX_Tensors->tensors [i++]);
11441145 output_S->data .dptr = nullptr ;
1145- if (q_format == NVTE_QKV_Format:: NVTE_THD && cudnn_runtime_version >= 90600 ) {
1146+ if (use_ragged_stats ) {
11461147 output_S->data .shape = {num_tokens_q, num_attn_heads, 1 };
11471148 } else {
11481149 output_S->data .shape = {batch, num_attn_heads, max_seqlen_q, 1 };
@@ -1152,8 +1153,7 @@ void fused_attn_arbitrary_seqlen_fwd(
11521153 if (return_max_logit) {
11531154 Tensor *output_Max = convertNVTETensorCheck (Aux_CTX_Tensors->tensors [i++]);
11541155 output_Max->data .dptr = nullptr ;
1155- if ((q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600 ) &&
1156- (sm_arch_ != 120 )) {
1156+ if (use_ragged_stats) {
11571157 output_Max->data .shape = {num_tokens_q, num_attn_heads, 1 };
11581158 } else {
11591159 output_Max->data .shape = {batch, num_attn_heads, max_seqlen_q, 1 };
0 commit comments