Skip to content

Commit ae72f97

Browse files
committed
updates
1 parent 195e52c commit ae72f97

1 file changed

Lines changed: 8 additions & 12 deletions

File tree

src/diffusers/models/attention_dispatch.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ class AttentionBackendName(str, Enum):
172172
_FLASH_3 = "_flash_3"
173173
_FLASH_VARLEN_3 = "_flash_varlen_3"
174174
_FLASH_3_HUB = "_flash_3_hub"
175-
_FLASH_VARLEN_3_HUB = "_flash_varlen_3_hub"
175+
_FLASH_3_VARLEN_HUB = "_flash_3_varlen_hub"
176176

177177
# `aiter`
178178
AITER = "aiter"
@@ -264,10 +264,10 @@ class _HubKernelConfig:
264264
AttentionBackendName._FLASH_3_HUB: _HubKernelConfig(
265265
repo_id="kernels-community/flash-attn3", function_attr="flash_attn_func", revision="fake-ops-return-probs"
266266
),
267-
AttentionBackendName._FLASH_VARLEN_3_HUB: _HubKernelConfig(
267+
AttentionBackendName._FLASH_3_VARLEN_HUB: _HubKernelConfig(
268268
repo_id="kernels-community/flash-attn3",
269269
function_attr="flash_attn_varlen_func",
270-
revision="fake-ops-return-probs",
270+
# revision="fake-ops-return-probs",
271271
),
272272
AttentionBackendName.FLASH_HUB: _HubKernelConfig(
273273
repo_id="kernels-community/flash-attn2", function_attr="flash_attn_func", revision=None
@@ -438,7 +438,7 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
438438
AttentionBackendName.FLASH_HUB,
439439
AttentionBackendName.FLASH_VARLEN_HUB,
440440
AttentionBackendName._FLASH_3_HUB,
441-
AttentionBackendName._FLASH_VARLEN_3_HUB,
441+
AttentionBackendName._FLASH_3_VARLEN_HUB,
442442
AttentionBackendName.SAGE_HUB,
443443
]:
444444
if not is_kernels_available():
@@ -1552,7 +1552,6 @@ def _flash_attention_3(
15521552
softmax_scale=scale,
15531553
causal=is_causal,
15541554
)
1555-
out = _maybe_unflatten_attention_heads(out, query)
15561555
return (out, lse) if return_lse else out
15571556

15581557

@@ -1597,17 +1596,15 @@ def _flash_attention_3_hub(
15971596
)
15981597
# When `return_attn_probs` is True, the above returns a tuple of
15991598
# actual outputs and lse.
1600-
if return_attn_probs:
1601-
return (_maybe_unflatten_attention_heads(out[0], query), out[1])
1602-
return _maybe_unflatten_attention_heads(out, query)
1599+
return (out[0], out[1]) if return_attn_probs else out
16031600

16041601

16051602
@_AttentionBackendRegistry.register(
1606-
AttentionBackendName._FLASH_VARLEN_3_HUB,
1603+
AttentionBackendName._FLASH_3_VARLEN_HUB,
16071604
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
16081605
supports_context_parallel=False,
16091606
)
1610-
def _flash_varlen_attention_3_hub(
1607+
def _flash_attention_3_varlen_hub(
16111608
query: torch.Tensor,
16121609
key: torch.Tensor,
16131610
value: torch.Tensor,
@@ -1639,7 +1636,7 @@ def _flash_varlen_attention_3_hub(
16391636
key_packed = torch.cat(key_valid, dim=0)
16401637
value_packed = torch.cat(value_valid, dim=0)
16411638

1642-
func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_VARLEN_3_HUB].kernel_fn
1639+
func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_VARLEN_HUB].kernel_fn
16431640
out, lse, *_ = func(
16441641
q=query_packed,
16451642
k=key_packed,
@@ -1652,7 +1649,6 @@ def _flash_varlen_attention_3_hub(
16521649
causal=is_causal,
16531650
)
16541651
out = out.unflatten(0, (batch_size, -1))
1655-
out = _maybe_unflatten_attention_heads(out, query)
16561652

16571653
return (out, lse) if return_lse else out
16581654

0 commit comments

Comments
 (0)