File tree Expand file tree Collapse file tree
transformer_engine/common/fused_attn Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -86,7 +86,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
8686 const auto cudnn_runtime_version = cudnnGetVersion ();
8787 const int device_id = cuda::current_device ();
8888 const int sm_arch_ = cuda::sm_arch (device_id);
89- bool use_ragged_stats = is_ragged_q && sm_arch_ >= 90 && cudnn_runtime_version >= 90600 && sm_arch_ != 120 ;
89+ bool use_ragged_stats =
90+ is_ragged_q && sm_arch_ >= 90 && cudnn_runtime_version >= 90600 && sm_arch_ != 120 ;
9091
9192 NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group (qkv_layout);
9293 bool is_paged_kv = (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD);
@@ -588,7 +589,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
588589 const auto cudnn_runtime_version = cudnnGetVersion ();
589590 const int device_id = cuda::current_device ();
590591 const int sm_arch_ = cuda::sm_arch (device_id);
591- bool use_ragged_stats = is_ragged_q && sm_arch_ >= 90 && cudnn_runtime_version >= 90600 && sm_arch_ != 120 ;
592+ bool use_ragged_stats =
593+ is_ragged_q && sm_arch_ >= 90 && cudnn_runtime_version >= 90600 && sm_arch_ != 120 ;
592594
593595 NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group (qkv_layout);
594596 bool is_paged_kv = (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD);
You can’t perform that action at this time.
0 commit comments