Skip to content

Commit 514d032

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 351cd84 commit 514d032

1 file changed

Lines changed: 4 additions & 2 deletions

File tree

transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
8686
const auto cudnn_runtime_version = cudnnGetVersion();
8787
const int device_id = cuda::current_device();
8888
const int sm_arch_ = cuda::sm_arch(device_id);
89-
bool use_ragged_stats = is_ragged_q && sm_arch_ >= 90 && cudnn_runtime_version >= 90600 && sm_arch_ != 120;
89+
bool use_ragged_stats =
90+
is_ragged_q && sm_arch_ >= 90 && cudnn_runtime_version >= 90600 && sm_arch_ != 120;
9091

9192
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
9293
bool is_paged_kv = (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD);
@@ -588,7 +589,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
588589
const auto cudnn_runtime_version = cudnnGetVersion();
589590
const int device_id = cuda::current_device();
590591
const int sm_arch_ = cuda::sm_arch(device_id);
591-
bool use_ragged_stats = is_ragged_q && sm_arch_ >= 90 && cudnn_runtime_version >= 90600 && sm_arch_ != 120;
592+
bool use_ragged_stats =
593+
is_ragged_q && sm_arch_ >= 90 && cudnn_runtime_version >= 90600 && sm_arch_ != 120;
592594

593595
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
594596
bool is_paged_kv = (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD);

0 commit comments

Comments
 (0)