
Commit 72ea121

zhtmike and sayakpaul authored
add SP support for flash_varlen_hub backend (#13479)
* add mask support for flash backend
* fix test case
* refactor test
* add protection
* fix comment
* update according to suggestion
* revert change
* fix according to claude review
* add test coverage for QwenImage
* add SP support and fix non-contiguous mask for flash_varlen kernel
* revert change
* Update tests/models/testing_utils/parallelism.py (Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>)
* Update tests/models/testing_utils/parallelism.py (Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>)
* drop `_padded_to_unpad`
* follow `if _parallel_config is None` pattern
* rename `attn_mask_2d`
* move check to the top
* make comment clear
* move non-contiguous-attention-mask as default dummy data
* revert and update

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent d773308 commit 72ea121

4 files changed

Lines changed: 271 additions & 35 deletions


src/diffusers/models/attention_dispatch.py

Lines changed: 239 additions & 35 deletions
@@ -352,6 +352,8 @@ class _HubKernelConfig:
     AttentionBackendName.FLASH_VARLEN_HUB: _HubKernelConfig(
         repo_id="kernels-community/flash-attn2",
         function_attr="flash_attn_varlen_func",
+        wrapped_forward_attr="flash_attn_interface._wrapped_flash_attn_varlen_forward",
+        wrapped_backward_attr="flash_attn_interface._wrapped_flash_attn_varlen_backward",
         version=1,
     ),
     AttentionBackendName.SAGE_HUB: _HubKernelConfig(
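
The new `wrapped_*_attr` strings are dotted attribute paths into the hub-loaded kernel module; the forward/backward ops below read them back via `config.wrapped_forward_fn` / `config.wrapped_backward_fn`. A plausible sketch of how such a path could be resolved (the actual loading lives in diffusers' hub-kernel plumbing, which is not part of this diff; `_resolve_attr` is a hypothetical helper):

from kernels import get_kernel  # huggingface `kernels` package

kernel_module = get_kernel("kernels-community/flash-attn2")

def _resolve_attr(module, dotted_path: str):
    # Walk "a.b.c" attribute paths on the loaded kernel module.
    obj = module
    for name in dotted_path.split("."):
        obj = getattr(obj, name)
    return obj

wrapped_forward_fn = _resolve_attr(kernel_module, "flash_attn_interface._wrapped_flash_attn_varlen_forward")
wrapped_backward_fn = _resolve_attr(kernel_module, "flash_attn_interface._wrapped_flash_attn_varlen_backward")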
@@ -636,6 +638,13 @@ def _prepare_for_flash_attn_or_sage_varlen(
     return _prepare_for_flash_attn_or_sage_varlen_with_mask(batch_size, seq_len_q, attn_mask, device)


+def _unpad_to_padded(packed: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int) -> torch.Tensor:
+    """scatter a packed `(nnz, ...)` tensor back to padded `(batch_size, seq_len, ...)`."""
+    output = torch.zeros(batch_size * seq_len, *packed.shape[1:], dtype=packed.dtype, device=packed.device)
+    output[indices] = packed
+    return output.view(batch_size, seq_len, *packed.shape[1:])
+
+
 def _normalize_attn_mask(attn_mask: torch.Tensor, batch_size: int, seq_len_k: int) -> torch.Tensor:
     """
     Normalize an attention mask to shape [batch_size, seq_len_k] (bool) suitable for inferring seqlens_[q|k] in
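
The new `_unpad_to_padded` helper is the inverse of the mask-based packing used by the varlen ops below: valid rows are gathered into an `(nnz, ...)` tensor for the kernel and later scattered back into zero-padded form (needed for key/value gradients). A standalone sketch of that round trip, using plain `torch`; the tensor names are illustrative, not from the diff:

import torch

batch_size, seq_len, head_dim = 2, 4, 8
mask = torch.tensor([[1, 1, 0, 0], [1, 0, 1, 0]], dtype=torch.bool)  # per-token validity, non-contiguous in row 1
x = torch.randn(batch_size, seq_len, head_dim)

# Pack: keep only the valid (mask == True) rows, as done for key/value in the varlen ops.
indices = mask.flatten().nonzero(as_tuple=False).flatten()
packed = x.reshape(-1, head_dim)[indices]  # shape (nnz, head_dim)

# Unpack: what `_unpad_to_padded(packed, indices, batch_size, seq_len)` computes.
padded = torch.zeros(batch_size * seq_len, head_dim, dtype=packed.dtype, device=packed.device)
padded[indices] = packed
padded = padded.view(batch_size, seq_len, head_dim)

assert torch.equal(padded[mask], x[mask])  # valid rows restored in place
assert torch.all(padded[~mask] == 0)       # padded positions stay zero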
@@ -1292,6 +1301,178 @@ def _flash_attention_hub_backward_op(
     return grad_query, grad_key, grad_value


+def _flash_varlen_attention_hub_forward_op(
+    ctx: torch.autograd.function.FunctionCtx,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attn_mask: torch.Tensor | None = None,
+    dropout_p: float = 0.0,
+    is_causal: bool = False,
+    scale: float | None = None,
+    enable_gqa: bool = False,
+    return_lse: bool = False,
+    _save_ctx: bool = True,
+    _parallel_config: "ParallelConfig" | None = None,
+    *,
+    window_size: tuple[int, int] = (-1, -1),
+):
+    if enable_gqa:
+        raise ValueError("`enable_gqa` is not yet supported for flash-attn varlen hub kernels.")
+
+    config = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_VARLEN_HUB]
+    wrapped_forward_fn = config.wrapped_forward_fn
+    wrapped_backward_fn = config.wrapped_backward_fn
+    if wrapped_forward_fn is None or wrapped_backward_fn is None:
+        raise RuntimeError(
+            "Flash attention varlen hub kernels must expose `_wrapped_flash_attn_varlen_forward` and "
+            "`_wrapped_flash_attn_varlen_backward` for context parallel execution."
+        )
+
+    if scale is None:
+        scale = query.shape[-1] ** (-0.5)
+
+    softcap = 0.0
+    alibi_slopes = None
+    deterministic = False
+    grad_enabled = any(x.requires_grad for x in (query, key, value))
+
+    if grad_enabled or (_parallel_config is not None and _parallel_config.context_parallel_config._world_size > 1):
+        dropout_p = dropout_p if dropout_p > 0 else 1e-30
+
+    batch_size, seq_len_q, num_heads, _ = query.shape
+    _, seq_len_kv, _, _ = key.shape
+
+    if attn_mask is not None:
+        attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
+        (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (_, max_seqlen_k) = (
+            _prepare_for_flash_attn_or_sage_varlen_with_mask(batch_size, seq_len_q, attn_mask, query.device)
+        )
+        indices_k = attn_mask.flatten().nonzero(as_tuple=False).flatten()
+        query_packed = query.flatten(0, 1)
+        key_packed = key.reshape(-1, *key.shape[2:])[indices_k]
+        value_packed = value.reshape(-1, *value.shape[2:])[indices_k]
+        max_seqlen_q = seq_len_q
+    else:
+        (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
+            _prepare_for_flash_attn_or_sage_varlen_without_mask(batch_size, seq_len_q, seq_len_kv, query.device)
+        )
+        query_packed = query.flatten(0, 1)
+        key_packed = key.flatten(0, 1)
+        value_packed = value.flatten(0, 1)
+        seqlens_k = None
+
+    with torch.set_grad_enabled(grad_enabled):
+        out_packed, lse, _, rng_state = wrapped_forward_fn(
+            query_packed,
+            key_packed,
+            value_packed,
+            cu_seqlens_q,
+            cu_seqlens_k,
+            max_seqlen_q,
+            max_seqlen_k,
+            dropout_p,
+            scale,
+            is_causal,
+            window_size[0],
+            window_size[1],
+            softcap,
+            alibi_slopes,
+            return_lse,
+        )
+
+    out = out_packed.view(batch_size, seq_len_q, *out_packed.shape[1:])
+
+    if _save_ctx:
+        ctx.save_for_backward(
+            query_packed, key_packed, value_packed, out_packed, lse, rng_state, cu_seqlens_q, cu_seqlens_k
+        )
+        ctx.seqlens_k = seqlens_k  # None if unmasked
+        ctx.indices_k = indices_k if attn_mask is not None else None
+        ctx.max_seqlen_q = max_seqlen_q
+        ctx.max_seqlen_k = max_seqlen_k
+        ctx.batch_size = batch_size
+        ctx.seq_len_q = seq_len_q
+        ctx.seq_len_kv = seq_len_kv
+        ctx.num_heads = num_heads
+        ctx.dropout_p = dropout_p
+        ctx.scale = scale
+        ctx.is_causal = is_causal
+        ctx.window_size = window_size
+        ctx.softcap = softcap
+        ctx.alibi_slopes = alibi_slopes
+        ctx.deterministic = deterministic
+
+    # (num_heads, batch_size * seq_len_q) -> (batch_size, seq_len_q, num_heads)
+    lse_sp = lse.view(num_heads, batch_size, seq_len_q).permute(1, 2, 0).contiguous()
+
+    return (out, lse_sp) if return_lse else out
+
+
+def _flash_varlen_attention_hub_backward_op(
+    ctx: torch.autograd.function.FunctionCtx,
+    grad_out: torch.Tensor,
+    *args,
+    **kwargs,
+):
+    config = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_VARLEN_HUB]
+    wrapped_backward_fn = config.wrapped_backward_fn
+    if wrapped_backward_fn is None:
+        raise RuntimeError(
+            "Flash attention varlen hub kernels must expose `_wrapped_flash_attn_varlen_backward` "
+            "for context parallel execution."
+        )
+
+    query_packed, key_packed, value_packed, out_packed, lse, rng_state, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors
+
+    grad_out_packed = grad_out.flatten(0, 1)
+    grad_query, grad_key, grad_value = (
+        torch.empty_like(query_packed),
+        torch.empty_like(key_packed),
+        torch.empty_like(value_packed),
+    )
+
+    _ = wrapped_backward_fn(
+        grad_out_packed,
+        query_packed,
+        key_packed,
+        value_packed,
+        out_packed,
+        lse,
+        grad_query,
+        grad_key,
+        grad_value,
+        cu_seqlens_q,
+        cu_seqlens_k,
+        ctx.max_seqlen_q,
+        ctx.max_seqlen_k,
+        ctx.dropout_p,
+        ctx.scale,
+        ctx.is_causal,
+        ctx.window_size[0],
+        ctx.window_size[1],
+        ctx.softcap,
+        ctx.alibi_slopes,
+        ctx.deterministic,
+        rng_state,
+    )
+
+    grad_query = grad_query.view(ctx.batch_size, ctx.seq_len_q, *grad_query.shape[1:])
+
+    if ctx.seqlens_k is not None:
+        grad_key = _unpad_to_padded(grad_key, ctx.indices_k, ctx.batch_size, ctx.seq_len_kv)
+        grad_value = _unpad_to_padded(grad_value, ctx.indices_k, ctx.batch_size, ctx.seq_len_kv)
+    else:
+        grad_key = grad_key.view(ctx.batch_size, ctx.seq_len_kv, *grad_key.shape[1:])
+        grad_value = grad_value.view(ctx.batch_size, ctx.seq_len_kv, *grad_value.shape[1:])
+
+    grad_query = grad_query[..., : grad_out.shape[-1]]
+    grad_key = grad_key[..., : grad_out.shape[-1]]
+    grad_value = grad_value[..., : grad_out.shape[-1]]
+
+    return grad_query, grad_key, grad_value
+
+
 def _flash_attention_3_hub_forward_op(
     ctx: torch.autograd.function.FunctionCtx,
     query: torch.Tensor,
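
For orientation, the varlen metadata the forward op unpacks — per-batch key lengths, cumulative sequence offsets, and maxima — can be derived from a boolean mask roughly as below. This is a sketch of what `_prepare_for_flash_attn_or_sage_varlen_with_mask` is assumed to return, inferred from the unpacking pattern above; it is not the actual helper:

import torch
import torch.nn.functional as F

batch_size, seq_len_q, seq_len_kv = 2, 4, 6
attn_mask = torch.tensor([[1, 1, 1, 0, 0, 0], [1, 0, 1, 1, 0, 0]], dtype=torch.bool)

seqlens_q = torch.full((batch_size,), seq_len_q, dtype=torch.int32)
seqlens_k = attn_mask.sum(dim=1).to(torch.int32)  # valid key tokens per batch entry
cu_seqlens_q = F.pad(torch.cumsum(seqlens_q, dim=0), (1, 0)).to(torch.int32)  # [0, 4, 8]
cu_seqlens_k = F.pad(torch.cumsum(seqlens_k, dim=0), (1, 0)).to(torch.int32)  # [0, 3, 6]
max_seqlen_q = seq_len_q
max_seqlen_k = int(seqlens_k.max())

# Because queries stay fully padded (`max_seqlen_q = seq_len_q`), the kernel output can simply be
# `.view(batch_size, seq_len_q, ...)`-ed back, while packed key/value gradients need the
# `_unpad_to_padded` scatter shown earlier.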
@@ -2557,7 +2738,7 @@ def _flash_attention_hub(
 @_AttentionBackendRegistry.register(
     AttentionBackendName.FLASH_VARLEN_HUB,
     constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
-    supports_context_parallel=False,
+    supports_context_parallel=True,
 )
 def _flash_varlen_attention_hub(
     query: torch.Tensor,
@@ -2571,46 +2752,69 @@ def _flash_varlen_attention_hub(
     return_lse: bool = False,
     _parallel_config: "ParallelConfig" | None = None,
 ) -> torch.Tensor:
+    if _parallel_config is not None and _parallel_config.context_parallel_config.ring_degree > 1:
+        raise NotImplementedError("`ring_degree > 1` is not yet supported for the FLASH_VARLEN_HUB backend.")
+
+    lse = None
     batch_size, seq_len_q, _, _ = query.shape
     _, seq_len_kv, _, _ = key.shape

-    if attn_mask is not None:
-        attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
-
-    (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
-        _prepare_for_flash_attn_or_sage_varlen(
-            batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
-        )
-    )
+    if _parallel_config is None:
+        if attn_mask is not None:
+            attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
+            (_, _), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
+                _prepare_for_flash_attn_or_sage_varlen_with_mask(batch_size, seq_len_q, attn_mask, query.device)
+            )
+            indices_k = attn_mask.flatten().nonzero(as_tuple=False).flatten()
+            key_packed = key.reshape(-1, *key.shape[2:])[indices_k]
+            value_packed = value.reshape(-1, *value.shape[2:])[indices_k]
+        else:
+            (_, _), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
+                _prepare_for_flash_attn_or_sage_varlen_without_mask(batch_size, seq_len_q, seq_len_kv, query.device)
+            )
+            key_packed = key.flatten(0, 1)
+            value_packed = value.flatten(0, 1)

-    key_valid, value_valid = [], []
-    for b in range(batch_size):
-        valid_len = seqlens_k[b]
-        key_valid.append(key[b, :valid_len])
-        value_valid.append(value[b, :valid_len])
+        query_packed = query.flatten(0, 1)

-    query_packed = query.flatten(0, 1)
-    key_packed = torch.cat(key_valid, dim=0)
-    value_packed = torch.cat(value_valid, dim=0)
-
-    func = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_VARLEN_HUB].kernel_fn
-    out = func(
-        q=query_packed,
-        k=key_packed,
-        v=value_packed,
-        cu_seqlens_q=cu_seqlens_q,
-        cu_seqlens_k=cu_seqlens_k,
-        max_seqlen_q=max_seqlen_q,
-        max_seqlen_k=max_seqlen_k,
-        dropout_p=dropout_p,
-        softmax_scale=scale,
-        causal=is_causal,
-        window_size=window_size,
-        return_attn_probs=return_lse,
-    )
-    out = out.unflatten(0, (batch_size, -1))
+        func = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_VARLEN_HUB].kernel_fn
+        out = func(
+            q=query_packed,
+            k=key_packed,
+            v=value_packed,
+            cu_seqlens_q=cu_seqlens_q,
+            cu_seqlens_k=cu_seqlens_k,
+            max_seqlen_q=max_seqlen_q,
+            max_seqlen_k=max_seqlen_k,
+            dropout_p=dropout_p,
+            softmax_scale=scale,
+            causal=is_causal,
+            window_size=window_size,
+            return_attn_probs=return_lse,
+        )
+        if return_lse:
+            out, lse, *_ = out
+        out = out.unflatten(0, (batch_size, -1))
+    else:
+        forward_op = functools.partial(_flash_varlen_attention_hub_forward_op, window_size=window_size)
+        out = _templated_context_parallel_attention(
+            query,
+            key,
+            value,
+            attn_mask,
+            dropout_p,
+            is_causal,
+            scale,
+            False,
+            return_lse,
+            forward_op=forward_op,
+            backward_op=_flash_varlen_attention_hub_backward_op,
+            _parallel_config=_parallel_config,
+        )
+        if return_lse:
+            out, lse = out

-    return out
+    return (out, lse) if return_lse else out


 @_AttentionBackendRegistry.register(
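
With the constraint lifted (`supports_context_parallel=True`), the varlen hub backend can be combined with Ulysses-style sequence parallelism, while ring parallelism is still rejected by the `ring_degree > 1` check. A rough usage sketch follows; the entry points (`set_attention_backend`, `enable_parallelism`, `ContextParallelConfig`) are assumptions based on the parallel-config plumbing above and may differ across diffusers versions, and the `kernels` package must be installed for hub-downloaded attention kernels:

# Hypothetical 2-GPU launch via `torchrun --nproc-per-node=2 infer.py`; not taken from this PR.
import torch
import torch.distributed as dist
from diffusers import QwenImagePipeline, ContextParallelConfig

dist.init_process_group("nccl")
device = torch.device("cuda", dist.get_rank() % torch.cuda.device_count())
torch.cuda.set_device(device)

pipe = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16).to(device)
pipe.transformer.set_attention_backend("flash_varlen_hub")  # hub varlen kernel handles the joint text+image mask
pipe.transformer.enable_parallelism(config=ContextParallelConfig(ulysses_degree=2))  # SP; ring_degree > 1 would raise

image = pipe("a photo of a cat", num_inference_steps=20).images[0]
if dist.get_rank() == 0:
    image.save("out.png")
dist.destroy_process_group()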

tests/models/testing_utils/parallelism.py

Lines changed: 11 additions & 0 deletions
@@ -374,6 +374,8 @@ def test_context_parallel_custom_mesh(self, cp_type, mesh_shape, mesh_dim_names)
 @is_context_parallel
 @require_torch_multi_accelerator
 class ContextParallelAttentionBackendsTesterMixin:
+    unsupported_attn_backends: list[str] = []
+
     @pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"])
     @pytest.mark.parametrize(
         "attention_backend",
@@ -383,6 +385,10 @@ class ContextParallelAttentionBackendsTesterMixin:
                 "flash_hub",
                 marks=pytest.mark.skipif(not is_kernels_available(), reason="`kernels` is not available."),
             ),
+            pytest.param(
+                "flash_varlen_hub",
+                marks=pytest.mark.skipif(not is_kernels_available(), reason="`kernels` is not available."),
+            ),
             pytest.param(
                 "_flash_3_hub",
                 marks=pytest.mark.skipif(not is_kernels_available(), reason="`kernels` is not available."),
@@ -398,9 +404,14 @@ def test_context_parallel_attn_backend_inference(self, cp_type, attention_backen
         if getattr(self.model_class, "_cp_plan", None) is None:
             pytest.skip("Model does not have a _cp_plan defined for context parallel inference.")

+        if attention_backend in self.unsupported_attn_backends:
+            pytest.skip(f"{attention_backend} is not supported for this model.")
+
         if cp_type == "ring_degree":
             if attention_backend == AttentionBackendName.NATIVE:
                 pytest.skip("Skipping test because ring isn't supported with native attention backend.")
+            elif attention_backend in ("flash_varlen_hub",):
+                pytest.skip("`ring_degree` is not yet supported for varlen attention hub kernels.")

         if ulysses_anything and "ulysses" not in cp_type:
             pytest.skip("Skipping test as ulysses anything needs the ulysses degree set.")

tests/models/testing_utils/utils.py

Lines changed: 1 addition & 0 deletions
@@ -6,6 +6,7 @@
 _BF16_REQUIRED_BACKENDS = {
     AttentionBackendName._NATIVE_CUDNN,
     AttentionBackendName.FLASH_HUB,
+    AttentionBackendName.FLASH_VARLEN_HUB,
     AttentionBackendName._FLASH_3_HUB,
 }

tests/models/transformers/test_models_transformer_qwenimage.py

Lines changed: 20 additions & 0 deletions
@@ -25,6 +25,7 @@
     AttentionTesterMixin,
     BaseModelTesterConfig,
     BitsAndBytesTesterMixin,
+    ContextParallelAttentionBackendsTesterMixin,
     ContextParallelTesterMixin,
     LoraHotSwappingForModelTesterMixin,
     LoraTesterMixin,
@@ -253,6 +254,25 @@ class TestQwenImageTransformerContextParallel(QwenImageTransformerTesterConfig,
     """Context Parallel inference tests for QwenImage Transformer."""


+class TestQwenImageTransformerContextParallelAttnBackends(
+    QwenImageTransformerTesterConfig, ContextParallelAttentionBackendsTesterMixin
+):
+    """Context Parallel inference x attention backends tests for QwenImage Transformer"""
+
+    # QwenImage always passes a joint attention mask (text + image), which flash_hub and
+    # _flash_3_hub do not support.
+    unsupported_attn_backends = ["flash_hub", "_flash_3_hub"]
+
+    def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:
+        inputs = super().get_dummy_inputs(batch_size=batch_size)
+        encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"]
+        encoder_hidden_states_mask[:, 1] = 0
+        encoder_hidden_states_mask[:, 3] = 0
+        encoder_hidden_states_mask[:, 5:] = 0
+        inputs["encoder_hidden_states_mask"] = encoder_hidden_states_mask
+        return inputs
+
+
 class TestQwenImageTransformerLoRA(QwenImageTransformerTesterConfig, LoraTesterMixin):
     """LoRA adapter tests for QwenImage Transformer."""

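The overridden `get_dummy_inputs` zeroes scattered positions so the text mask has gaps rather than a single trailing pad block. Assuming the base config's `encoder_hidden_states_mask` starts out as all ones (an assumption about the default dummy inputs, not stated in the diff), the override yields a mask like the one below, which exercises the `indices_k` gather and `_unpad_to_padded` scatter paths that the previous contiguous `key[b, :valid_len]` slicing could not represent:

import torch

# Illustrative only: a text length of 8 chosen arbitrarily; the real default mask comes from the tester config.
mask = torch.ones(1, 8, dtype=torch.int64)
mask[:, 1] = 0
mask[:, 3] = 0
mask[:, 5:] = 0
print(mask)  # tensor([[1, 0, 1, 0, 1, 0, 0, 0]]) -- valid tokens are non-contiguous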