Skip to content

Commit 831cb06

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent db7c09e commit 831cb06

2 files changed

Lines changed: 4 additions & 2 deletions

File tree

transformer_engine/pytorch/attention/dot_product_attention/backends.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -169,6 +169,7 @@
169169
from flash_attn.cute.interface import flash_attn_varlen_func as flash_attn_varlen_func_v4
170170
from flash_attn.cute.interface import _flash_attn_fwd as _flash_attn_fwd_v4
171171
from flash_attn.cute.interface import _flash_attn_bwd as _flash_attn_bwd_v4
172+
172173
# flash_attn_with_kvcache_v4 = None # FA4 does not support kvcache yet
173174
fa_utils.set_flash_attention_4_params()
174175

transformer_engine/pytorch/attention/dot_product_attention/utils.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1128,8 +1128,9 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt
11281128
use_flash_attention_2 = False
11291129
if use_flash_attention_3 and deterministic and FlashAttentionUtils.v3_is_installed:
11301130
if head_dim_qk >= 256:
1131-
logger.debug("Disabling FlashAttention 3 for deterministic execution with "
1132-
"head_dim_qk >= 256.")
1131+
logger.debug(
1132+
"Disabling FlashAttention 3 for deterministic execution with head_dim_qk >= 256."
1133+
)
11331134
use_flash_attention_3 = False
11341135
if use_fused_attention and deterministic:
11351136
if fused_attention_backend == FusedAttnBackend["FP8"] and is_training:

0 commit comments

Comments
 (0)