Commit 9d68742

Authored by Bissmella, KarthikSundar2002, github-actions[bot], and sayakpaul
Add Unified Sequence Parallel attention (#12693)
* initial scheme of unified-sp
* initial all_to_all_double
* bug fixes, added comments
* unified attention prototype done
* remove raising ValueError in ContextParallelConfig to enable unified attention
* feat: adds test for Unified SP Attention and fixes a bug in templated ring attention
* bug fixes in lse calculation and testing; switched to _all_to_all_single helper in _all_to_all_dim_exchange due to contiguity issues
* addressing comments
* sequence parallelism bug fixes
* code format and style fixes
* added unified attention docs and removed test file
* tip for unified attention in docs at distributed_inference.md
* update distributed_inference.md, adding benchmarks
* function name fix
* fixed benchmark in docs

Co-authored-by: KarthikSundar2002 <karthiksundar30092002@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent f1a93c7 commit 9d68742

File tree: 3 files changed, +212 −8 lines changed

docs/source/en/training/distributed_inference.md (28 additions, 0 deletions)

The following section is appended after the existing context-parallel examples:

### Unified Attention

[Unified Sequence Parallelism](https://huggingface.co/papers/2405.07719) combines Ring Attention and Ulysses Attention into a single approach for efficient long-sequence processing. It applies Ulysses's *all-to-all* communication first to redistribute heads and sequence tokens, then uses Ring Attention to process the redistributed data, and finally reverses the *all-to-all* to restore the original layout.

This hybrid approach leverages the strengths of both methods:

- **Ulysses Attention** efficiently parallelizes across attention heads
- **Ring Attention** handles very long sequences with minimal memory overhead
- Together, they enable 2D parallelization across both the head and sequence dimensions

[`ContextParallelConfig`] supports Unified Attention when both `ulysses_degree` and `ring_degree` are specified. The total number of devices used is `ulysses_degree * ring_degree`, arranged in a 2D grid where the Ulysses and Ring groups are orthogonal (non-overlapping). Pass a [`ContextParallelConfig`] with both `ulysses_degree` and `ring_degree` set to values greater than 1 to [`~ModelMixin.enable_parallelism`].

```py
pipeline.transformer.enable_parallelism(config=ContextParallelConfig(ulysses_degree=2, ring_degree=2))
```

> [!TIP]
> Use Unified Attention when there are enough devices to arrange in a 2D grid (at least 4 devices).

We benchmarked Ulysses, Ring, and Unified Attention with [this script](https://github.com/huggingface/diffusers/pull/12693#issuecomment-3694727532) on a node of 4 H100 GPUs. The results are summarized as follows:

| CP Backend       | Time / Iter (ms) | Steps / Sec | Peak Memory (GB) |
|------------------|------------------|-------------|------------------|
| ulysses          | 6670.789         | 7.50        | 33.85            |
| ring             | 13076.492        | 3.82        | 56.02            |
| unified_balanced | 11068.705        | 4.52        | 33.85            |

From the table above, Ulysses clearly provides better throughput, but the number of devices it can use is limited by the number of attention heads, a limitation that Unified Attention removes.
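The orthogonal 2D grid described above can be sketched in plain Python. This is an illustrative toy, not the actual mesh construction in diffusers; `make_2d_mesh` is a hypothetical helper and the row-major layout is an assumption:

```python
def make_2d_mesh(ulysses_degree: int, ring_degree: int):
    """Arrange ulysses_degree * ring_degree ranks into an orthogonal 2D grid (toy layout)."""
    # Row-major grid: grid[r][u] is the rank at ring index r, Ulysses index u.
    grid = [[r * ulysses_degree + u for u in range(ulysses_degree)] for r in range(ring_degree)]
    ulysses_groups = grid                            # rows: ranks that all-to-all heads/tokens
    ring_groups = [list(col) for col in zip(*grid)]  # columns: ranks that rotate KV blocks
    return ulysses_groups, ring_groups

ulysses_groups, ring_groups = make_2d_mesh(ulysses_degree=2, ring_degree=2)
print(ulysses_groups)  # [[0, 1], [2, 3]]
print(ring_groups)     # [[0, 2], [1, 3]]
```

Each rank belongs to exactly one Ulysses group and one Ring group, which is why at least `2 * 2 = 4` devices are needed.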

src/diffusers/models/_modeling_parallel.py (0 additions, 4 deletions)

The guard in `__post_init__` that rejected combined Ulysses-Ring attention is removed:

```diff
         if self.ring_degree < 1 or self.ulysses_degree < 1:
             raise ValueError("`ring_degree` and `ulysses_degree` must be greater than or equal to 1.")
-        if self.ring_degree > 1 and self.ulysses_degree > 1:
-            raise ValueError(
-                "Unified Ulysses-Ring attention is not yet supported. Please set either `ring_degree` or `ulysses_degree` to 1."
-            )
         if self.rotate_method != "allgather":
             raise NotImplementedError(
                 f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}."
```

src/diffusers/models/attention_dispatch.py (184 additions, 4 deletions)

A dimension-exchange helper and its autograd wrapper are added after `_all_to_all_single`:

```python
def _all_to_all_dim_exchange(x: torch.Tensor, scatter_idx: int = 2, gather_idx: int = 1, group=None) -> torch.Tensor:
    """
    Perform dimension sharding / reassembly across processes using _all_to_all_single.

    This utility reshapes and redistributes tensor `x` across the given process group, along either the
    sequence dimension or the head dimension, as selected by `scatter_idx` and `gather_idx`.

    Args:
        x (torch.Tensor):
            Input tensor. Expected shapes:
            - When scatter_idx=2, gather_idx=1: (batch_size, seq_len_local, num_heads, head_dim)
            - When scatter_idx=1, gather_idx=2: (batch_size, seq_len, num_heads_local, head_dim)
        scatter_idx (int):
            Dimension along which the tensor is partitioned before all-to-all.
        gather_idx (int):
            Dimension along which the output is reassembled after all-to-all.
        group:
            Distributed process group for the Ulysses group.

    Returns:
        torch.Tensor: Tensor with globally exchanged dimensions.
            - For (scatter_idx=2 → gather_idx=1): (batch_size, seq_len, num_heads_local, head_dim)
            - For (scatter_idx=1 → gather_idx=2): (batch_size, seq_len_local, num_heads, head_dim)
    """
    group_world_size = torch.distributed.get_world_size(group)

    if scatter_idx == 2 and gather_idx == 1:
        # Used before Ulysses sequence parallel (SP) attention. Gathers the sequence
        # dimension and scatters the head dimension.
        batch_size, seq_len_local, num_heads, head_dim = x.shape
        seq_len = seq_len_local * group_world_size
        num_heads_local = num_heads // group_world_size

        # B, S_LOCAL, H, D -> group_world_size, S_LOCAL, B, H_LOCAL, D
        x_temp = (
            x.reshape(batch_size, seq_len_local, group_world_size, num_heads_local, head_dim)
            .transpose(0, 2)
            .contiguous()
        )

        if group_world_size > 1:
            out = _all_to_all_single(x_temp, group=group)
        else:
            out = x_temp
        # group_world_size, S_LOCAL, B, H_LOCAL, D -> B, S, H_LOCAL, D
        out = out.reshape(seq_len, batch_size, num_heads_local, head_dim).permute(1, 0, 2, 3).contiguous()
        out = out.reshape(batch_size, seq_len, num_heads_local, head_dim)
        return out
    elif scatter_idx == 1 and gather_idx == 2:
        # Used after Ulysses sequence parallel in unified SP. Gathers the head dimension
        # and scatters back the sequence dimension.
        batch_size, seq_len, num_heads_local, head_dim = x.shape
        num_heads = num_heads_local * group_world_size
        seq_len_local = seq_len // group_world_size

        # B, S, H_LOCAL, D -> group_world_size, H_LOCAL, S_LOCAL, B, D
        x_temp = (
            x.reshape(batch_size, group_world_size, seq_len_local, num_heads_local, head_dim)
            .permute(1, 3, 2, 0, 4)
            .reshape(group_world_size, num_heads_local, seq_len_local, batch_size, head_dim)
        )

        if group_world_size > 1:
            output = _all_to_all_single(x_temp, group)
        else:
            output = x_temp
        output = output.reshape(num_heads, seq_len_local, batch_size, head_dim).transpose(0, 2).contiguous()
        output = output.reshape(batch_size, seq_len_local, num_heads, head_dim)
        return output
    else:
        raise ValueError("Invalid scatter/gather indices for _all_to_all_dim_exchange.")


class SeqAllToAllDim(torch.autograd.Function):
    """
    All-to-all operation for unified sequence parallelism. Uses _all_to_all_dim_exchange; see its
    docstring for details.
    """

    @staticmethod
    def forward(ctx, group, input, scatter_id=2, gather_id=1):
        ctx.group = group
        ctx.scatter_id = scatter_id
        ctx.gather_id = gather_id
        return _all_to_all_dim_exchange(input, scatter_id, gather_id, group)

    @staticmethod
    def backward(ctx, grad_outputs):
        # The backward of an all-to-all is the all-to-all in the opposite direction.
        grad_input = SeqAllToAllDim.apply(
            ctx.group,
            grad_outputs,
            ctx.gather_id,  # reversed
            ctx.scatter_id,  # reversed
        )
        return (None, grad_input, None, None)
```
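The head/sequence exchange performed by `_all_to_all_dim_exchange` can be checked on a single process by modeling each rank's shard as one slice of a global array. This is a NumPy toy under that assumption, not the distributed code path:

```python
import numpy as np

# Single-process model of the exchange: each "rank" holds one slice of a global
# (B, S, H, D) tensor, and the all-to-all is modeled by explicitly routing
# chunk u of rank r to rank u.
W = 4                      # simulated group world size
B, S, H, D = 2, 8, 4, 3    # global batch, sequence, heads, head_dim
assert S % W == 0 and H % W == 0

x = np.arange(B * S * H * D, dtype=np.float32).reshape(B, S, H, D)

# Before the exchange, rank r holds a sequence shard of shape (B, S/W, H, D).
shards = np.split(x, W, axis=1)

# scatter_idx=2, gather_idx=1: gather the sequence, scatter the heads.
# Rank r sends head-chunk u of its shard to rank u.
sent = [np.split(s, W, axis=2) for s in shards]
after = [np.concatenate([sent[r][u] for r in range(W)], axis=1) for u in range(W)]

# Rank u now holds the FULL sequence for head-chunk u: shape (B, S, H/W, D).
assert all(np.array_equal(a, e) for a, e in zip(after, np.split(x, W, axis=2)))

# scatter_idx=1, gather_idx=2 is the inverse: gather heads, re-scatter the sequence.
back = [
    np.concatenate([np.split(after[u], W, axis=1)[r] for u in range(W)], axis=2)
    for r in range(W)
]
assert all(np.array_equal(b, s) for b, s in zip(back, shards))
print("round trip restores the original sharding")
```

This is exactly the layout change unified attention relies on: after the first exchange every rank sees the whole sequence for a subset of heads, so Ring Attention can run on full-length sequences per head group.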
In `TemplatedRingAttention.forward`, unsqueezing the LSE is now gated on the PyTorch version:

```diff
             out = out.to(torch.float32)
             lse = lse.to(torch.float32)

-        lse = lse.unsqueeze(-1)
+        # Refer to:
+        # https://github.com/huggingface/diffusers/pull/12693#issuecomment-3627519544
+        if is_torch_version("<", "2.9.0"):
+            lse = lse.unsqueeze(-1)
         if prev_out is not None:
             out = prev_out - torch.nn.functional.sigmoid(lse - prev_lse) * (prev_out - out)
             lse = prev_lse - torch.nn.functional.logsigmoid(prev_lse - lse)
```
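The `sigmoid`/`logsigmoid` update in the last two lines of the hunk is the numerically stable way to merge partial softmax-attention results by their log-sum-exp. A small NumPy check (illustrative only: a single query, plain arrays instead of torch tensors) confirms that merging two key/value blocks this way reproduces full attention:

```python
import numpy as np

rng = np.random.default_rng(0)
q = rng.normal(size=(1, 6))                  # one query of dim 6
k = rng.normal(size=(10, 6))
v = rng.normal(size=(10, 4))

def block_attn(q, k, v):
    """Softmax attention over one block, returning output and log-sum-exp."""
    s = q @ k.T                              # (1, n) scores
    lse = np.log(np.exp(s).sum(axis=-1, keepdims=True))
    return np.exp(s - lse) @ v, lse

sigmoid = lambda z: 1.0 / (1.0 + np.exp(-z))

# Reference: full attention over all 10 keys.
ref, ref_lse = block_attn(q, k, v)

# Ring-style merge of two partial results, mirroring the diff:
#   out = prev_out - sigmoid(lse - prev_lse) * (prev_out - out)
#   lse = prev_lse - logsigmoid(prev_lse - lse)
prev_out, prev_lse = block_attn(q, k[:5], v[:5])
out, lse = block_attn(q, k[5:], v[5:])
merged = prev_out - sigmoid(lse - prev_lse) * (prev_out - out)
merged_lse = prev_lse - np.log(sigmoid(prev_lse - lse))

assert np.allclose(merged, ref)
assert np.allclose(merged_lse, ref_lse)
print("block merge matches full attention")
```

The identity works because `sigmoid(lse - prev_lse)` equals `exp(lse) / (exp(lse) + exp(prev_lse))`, i.e. the softmax weight of the new block relative to the accumulated one.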
`TemplatedRingAttention.backward` now returns one additional `None`:

```diff
         grad_query, grad_key, grad_value = (x.to(grad_out.dtype) for x in (grad_query, grad_key, grad_value))

-        return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None
+        return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None, None
```
`TemplatedUlyssesAttention.backward` gets the same extra `None`:

```diff
             x.flatten(0, 1).permute(1, 2, 0, 3).contiguous() for x in (grad_query, grad_key, grad_value)
         )

-        return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None
+        return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None, None
```

And a new `_templated_unified_attention` function is added:

```python
def _templated_unified_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor],
    dropout_p: float,
    is_causal: bool,
    scale: Optional[float],
    enable_gqa: bool,
    return_lse: bool,
    forward_op,
    backward_op,
    _parallel_config: Optional["ParallelConfig"] = None,
    scatter_idx: int = 2,
    gather_idx: int = 1,
):
    """
    Unified Sequence Parallelism attention combining Ulysses and ring attention. See: https://arxiv.org/abs/2405.07719
    """
    ulysses_mesh = _parallel_config.context_parallel_config._ulysses_mesh
    ulysses_group = ulysses_mesh.get_group()

    # Ulysses all-to-all: gather the sequence, scatter the heads.
    query = SeqAllToAllDim.apply(ulysses_group, query, scatter_idx, gather_idx)
    key = SeqAllToAllDim.apply(ulysses_group, key, scatter_idx, gather_idx)
    value = SeqAllToAllDim.apply(ulysses_group, value, scatter_idx, gather_idx)
    # Ring attention over the redistributed (full-sequence, local-heads) tensors.
    out = TemplatedRingAttention.apply(
        query,
        key,
        value,
        attn_mask,
        dropout_p,
        is_causal,
        scale,
        enable_gqa,
        return_lse,
        forward_op,
        backward_op,
        _parallel_config,
    )
    if return_lse:
        context_layer, lse, *_ = out
    else:
        context_layer = out
    # context_layer is of shape (B, S, H_LOCAL, D); reverse the all-to-all to
    # restore the original (local-sequence, all-heads) layout.
    output = SeqAllToAllDim.apply(
        ulysses_group,
        context_layer,
        gather_idx,
        scatter_idx,
    )
    if return_lse:
        # lse is of shape (B, S, H_LOCAL, 1)
        # Refer to:
        # https://github.com/huggingface/diffusers/pull/12693#issuecomment-3627519544
        if is_torch_version("<", "2.9.0"):
            lse = lse.unsqueeze(-1)  # (B, S, H_LOCAL, 1)
        lse = SeqAllToAllDim.apply(ulysses_group, lse, gather_idx, scatter_idx)
        lse = lse.squeeze(-1)
        return (output, lse)
    return output
```
Finally, `_templated_context_parallel_attention` dispatches to unified attention when both degrees are greater than 1:

```diff
         raise ValueError("GQA is not yet supported for templated attention.")

     # TODO: add support for unified attention with ring/ulysses degree both being > 1
-    if _parallel_config.context_parallel_config.ring_degree > 1:
+    if (
+        _parallel_config.context_parallel_config.ring_degree > 1
+        and _parallel_config.context_parallel_config.ulysses_degree > 1
+    ):
+        return _templated_unified_attention(
+            query,
+            key,
+            value,
+            attn_mask,
+            dropout_p,
+            is_causal,
+            scale,
+            enable_gqa,
+            return_lse,
+            forward_op,
+            backward_op,
+            _parallel_config,
+        )
+    elif _parallel_config.context_parallel_config.ring_degree > 1:
         return TemplatedRingAttention.apply(
             query,
             key,
```
