Skip to content

Commit 3105848

Browse files
authored
[attention backends] use dedicated wrappers from fa3 for cp. (#13165)
* Use the dedicated wrapped forward/backward primitives exposed by the FlashAttention-3 (FA3) hub kernel for context-parallel (CP) execution.
* Minor follow-up cleanup.
1 parent aac94be commit 3105848

File tree

1 file changed

+91
-66
lines changed

1 file changed

+91
-66
lines changed

src/diffusers/models/attention_dispatch.py

Lines changed: 91 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -329,7 +329,11 @@ class _HubKernelConfig:
329329
_HUB_KERNELS_REGISTRY: dict["AttentionBackendName", _HubKernelConfig] = {
330330
# TODO: temporary revision for now. Remove when merged upstream into `main`.
331331
AttentionBackendName._FLASH_3_HUB: _HubKernelConfig(
332-
repo_id="kernels-community/flash-attn3", function_attr="flash_attn_func", revision="fake-ops-return-probs"
332+
repo_id="kernels-community/flash-attn3",
333+
function_attr="flash_attn_func",
334+
revision="fake-ops-return-probs",
335+
wrapped_forward_attr="flash_attn_interface._flash_attn_forward",
336+
wrapped_backward_attr="flash_attn_interface._flash_attn_backward",
333337
),
334338
AttentionBackendName._FLASH_3_VARLEN_HUB: _HubKernelConfig(
335339
repo_id="kernels-community/flash-attn3",
@@ -1290,36 +1294,62 @@ def _flash_attention_3_hub_forward_op(
12901294
if enable_gqa:
12911295
raise ValueError("`enable_gqa` is not yet supported for flash-attn 3 hub kernels.")
12921296

1293-
func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn
1294-
out = func(
1295-
q=query,
1296-
k=key,
1297-
v=value,
1298-
softmax_scale=scale,
1297+
config = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB]
1298+
wrapped_forward_fn = config.wrapped_forward_fn
1299+
if wrapped_forward_fn is None:
1300+
raise RuntimeError(
1301+
"Flash attention 3 hub kernels must expose `flash_attn_interface._flash_attn_forward` "
1302+
"for context parallel execution."
1303+
)
1304+
1305+
if scale is None:
1306+
scale = query.shape[-1] ** (-0.5)
1307+
1308+
out, softmax_lse, *_ = wrapped_forward_fn(
1309+
query,
1310+
key,
1311+
value,
1312+
None,
1313+
None, # k_new, v_new
1314+
None, # qv
1315+
None, # out
1316+
None,
1317+
None,
1318+
None, # cu_seqlens_q/k/k_new
1319+
None,
1320+
None, # seqused_q/k
1321+
None,
1322+
None, # max_seqlen_q/k
1323+
None,
1324+
None,
1325+
None, # page_table, kv_batch_idx, leftpad_k
1326+
None,
1327+
None,
1328+
None, # rotary_cos/sin, seqlens_rotary
1329+
None,
1330+
None,
1331+
None, # q_descale, k_descale, v_descale
1332+
scale,
12991333
causal=is_causal,
1300-
qv=None,
1301-
q_descale=None,
1302-
k_descale=None,
1303-
v_descale=None,
1304-
window_size=window_size,
1334+
window_size_left=window_size[0],
1335+
window_size_right=window_size[1],
1336+
attention_chunk=0,
13051337
softcap=softcap,
13061338
num_splits=num_splits,
13071339
pack_gqa=pack_gqa,
1308-
deterministic=deterministic,
13091340
sm_margin=sm_margin,
1310-
return_attn_probs=return_lse,
13111341
)
13121342

1313-
lse = None
1314-
if return_lse:
1315-
out, lse = out
1316-
lse = lse.permute(0, 2, 1).contiguous()
1343+
lse = softmax_lse.permute(0, 2, 1).contiguous() if return_lse else None
13171344

13181345
if _save_ctx:
1319-
ctx.save_for_backward(query, key, value)
1346+
ctx.save_for_backward(query, key, value, out, softmax_lse)
13201347
ctx.scale = scale
13211348
ctx.is_causal = is_causal
1322-
ctx._hub_kernel = func
1349+
ctx.window_size = window_size
1350+
ctx.softcap = softcap
1351+
ctx.deterministic = deterministic
1352+
ctx.sm_margin = sm_margin
13231353

13241354
return (out, lse) if return_lse else out
13251355

@@ -1328,55 +1358,50 @@ def _flash_attention_3_hub_backward_op(
13281358
ctx: torch.autograd.function.FunctionCtx,
13291359
grad_out: torch.Tensor,
13301360
*args,
1331-
window_size: tuple[int, int] = (-1, -1),
1332-
softcap: float = 0.0,
1333-
num_splits: int = 1,
1334-
pack_gqa: bool | None = None,
1335-
deterministic: bool = False,
1336-
sm_margin: int = 0,
1361+
**kwargs,
13371362
):
1338-
query, key, value = ctx.saved_tensors
1339-
kernel_fn = ctx._hub_kernel
1340-
# NOTE: Unlike the FA2 hub kernel, the FA3 hub kernel does not expose separate wrapped forward/backward
1341-
# primitives (no `wrapped_forward_attr`/`wrapped_backward_attr` in its `_HubKernelConfig`). We
1342-
# therefore rerun the forward pass under `torch.enable_grad()` and differentiate through it with
1343-
# `torch.autograd.grad()`. This is a second forward pass during backward; it can be avoided once
1344-
# the FA3 hub exposes a dedicated fused backward kernel (analogous to `_wrapped_flash_attn_backward`
1345-
# in the FA2 hub), at which point this can be refactored to match `_flash_attention_hub_backward_op`.
1346-
with torch.enable_grad():
1347-
query_r = query.detach().requires_grad_(True)
1348-
key_r = key.detach().requires_grad_(True)
1349-
value_r = value.detach().requires_grad_(True)
1350-
1351-
out = kernel_fn(
1352-
q=query_r,
1353-
k=key_r,
1354-
v=value_r,
1355-
softmax_scale=ctx.scale,
1356-
causal=ctx.is_causal,
1357-
qv=None,
1358-
q_descale=None,
1359-
k_descale=None,
1360-
v_descale=None,
1361-
window_size=window_size,
1362-
softcap=softcap,
1363-
num_splits=num_splits,
1364-
pack_gqa=pack_gqa,
1365-
deterministic=deterministic,
1366-
sm_margin=sm_margin,
1367-
return_attn_probs=False,
1368-
)
1369-
if isinstance(out, tuple):
1370-
out = out[0]
1371-
1372-
grad_query, grad_key, grad_value = torch.autograd.grad(
1373-
out,
1374-
(query_r, key_r, value_r),
1375-
grad_out,
1376-
retain_graph=False,
1377-
allow_unused=False,
1363+
config = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB]
1364+
wrapped_backward_fn = config.wrapped_backward_fn
1365+
if wrapped_backward_fn is None:
1366+
raise RuntimeError(
1367+
"Flash attention 3 hub kernels must expose `flash_attn_interface._flash_attn_backward` "
1368+
"for context parallel execution."
13781369
)
13791370

1371+
query, key, value, out, softmax_lse = ctx.saved_tensors
1372+
grad_query = torch.empty_like(query)
1373+
grad_key = torch.empty_like(key)
1374+
grad_value = torch.empty_like(value)
1375+
1376+
wrapped_backward_fn(
1377+
grad_out,
1378+
query,
1379+
key,
1380+
value,
1381+
out,
1382+
softmax_lse,
1383+
None,
1384+
None, # cu_seqlens_q, cu_seqlens_k
1385+
None,
1386+
None, # seqused_q, seqused_k
1387+
None,
1388+
None, # max_seqlen_q, max_seqlen_k
1389+
grad_query,
1390+
grad_key,
1391+
grad_value,
1392+
ctx.scale,
1393+
ctx.is_causal,
1394+
ctx.window_size[0],
1395+
ctx.window_size[1],
1396+
ctx.softcap,
1397+
ctx.deterministic,
1398+
ctx.sm_margin,
1399+
)
1400+
1401+
grad_query = grad_query[..., : grad_out.shape[-1]]
1402+
grad_key = grad_key[..., : grad_out.shape[-1]]
1403+
grad_value = grad_value[..., : grad_out.shape[-1]]
1404+
13801405
return grad_query, grad_key, grad_value
13811406

13821407

0 commit comments

Comments (0)