@@ -1177,6 +1177,103 @@ def _all_to_all_single(x: torch.Tensor, group) -> torch.Tensor:
11771177 return x
11781178
11791179
def _all_to_all_dim_exchange(x: torch.Tensor, scatter_idx: int = 2, gather_idx: int = 1, group=None) -> torch.Tensor:
    """
    Perform dimension sharding / reassembly across processes using _all_to_all_single.

    This utility reshapes and redistributes tensor `x` across the given process group, across the sequence
    dimension or the head dimension flexibly by accepting `scatter_idx` and `gather_idx`.

    Args:
        x (torch.Tensor):
            Input tensor. Expected shapes:
                - When scatter_idx=2, gather_idx=1: (batch_size, seq_len_local, num_heads, head_dim)
                - When scatter_idx=1, gather_idx=2: (batch_size, seq_len, num_heads_local, head_dim)
        scatter_idx (int):
            Dimension along which the tensor is partitioned before all-to-all.
        gather_idx (int):
            Dimension along which the output is reassembled after all-to-all.
        group:
            Distributed process group for the Ulysses group.

    Returns:
        torch.Tensor: Tensor with globally exchanged dimensions.
            - For (scatter_idx=2 -> gather_idx=1): (batch_size, seq_len, num_heads_local, head_dim)
            - For (scatter_idx=1 -> gather_idx=2): (batch_size, seq_len_local, num_heads, head_dim)

    Raises:
        ValueError: If (scatter_idx, gather_idx) is neither (2, 1) nor (1, 2).
    """
    group_world_size = torch.distributed.get_world_size(group)

    if scatter_idx == 2 and gather_idx == 1:
        # Used before Ulysses sequence parallel (SP) attention: gathers the sequence
        # dimension and scatters the head dimension.
        batch_size, seq_len_local, num_heads, head_dim = x.shape
        seq_len = seq_len_local * group_world_size
        num_heads_local = num_heads // group_world_size

        # B, S_LOCAL, H, D -> group_world_size, S_LOCAL, B, H_LOCAL, D
        x_temp = (
            x.reshape(batch_size, seq_len_local, group_world_size, num_heads_local, head_dim)
            .transpose(0, 2)
            .contiguous()
        )

        if group_world_size > 1:
            out = _all_to_all_single(x_temp, group=group)
        else:
            out = x_temp
        # group_world_size, S_LOCAL, B, H_LOCAL, D -> B, S, H_LOCAL, D
        out = out.reshape(seq_len, batch_size, num_heads_local, head_dim).permute(1, 0, 2, 3).contiguous()
        return out
    elif scatter_idx == 1 and gather_idx == 2:
        # Used after Ulysses sequence parallel in unified SP: gathers the head dimension and
        # scatters back the sequence dimension.
        batch_size, seq_len, num_heads_local, head_dim = x.shape
        num_heads = num_heads_local * group_world_size
        seq_len_local = seq_len // group_world_size

        # B, S, H_LOCAL, D -> group_world_size, H_LOCAL, S_LOCAL, B, D
        # (the trailing reshape after the permute also forces a contiguous copy for the all-to-all)
        x_temp = (
            x.reshape(batch_size, group_world_size, seq_len_local, num_heads_local, head_dim)
            .permute(1, 3, 2, 0, 4)
            .reshape(group_world_size, num_heads_local, seq_len_local, batch_size, head_dim)
        )

        if group_world_size > 1:
            out = _all_to_all_single(x_temp, group=group)
        else:
            out = x_temp
        # group_world_size, H_LOCAL, S_LOCAL, B, D -> B, S_LOCAL, H, D
        out = out.reshape(num_heads, seq_len_local, batch_size, head_dim).transpose(0, 2).contiguous()
        return out
    else:
        raise ValueError("Invalid scatter/gather indices for _all_to_all_dim_exchange.")
1251+
1252+
class SeqAllToAllDim(torch.autograd.Function):
    """
    all_to_all operation for unified sequence parallelism. Uses _all_to_all_dim_exchange; see
    _all_to_all_dim_exchange for more info.
    """

    @staticmethod
    def forward(ctx, group, input, scatter_id=2, gather_id=1):
        # Stash the exchange parameters so backward can run the inverse exchange.
        ctx.group = group
        ctx.scatter_id = scatter_id
        ctx.gather_id = gather_id
        return _all_to_all_dim_exchange(input, scatter_id, gather_id, group)

    @staticmethod
    def backward(ctx, grad_outputs):
        # The gradient of an all-to-all is the all-to-all with the scatter and
        # gather dimensions swapped.
        grad_input = SeqAllToAllDim.apply(ctx.group, grad_outputs, ctx.gather_id, ctx.scatter_id)
        # Positions mirror forward(ctx, group, input, scatter_id, gather_id):
        # only `input` receives a gradient.
        return None, grad_input, None, None
1275+
1276+
11801277class TemplatedRingAttention (torch .autograd .Function ):
11811278 @staticmethod
11821279 def forward (
@@ -1237,7 +1334,10 @@ def forward(
12371334 out = out .to (torch .float32 )
12381335 lse = lse .to (torch .float32 )
12391336
1240- lse = lse .unsqueeze (- 1 )
1337+ # Refer to:
1338+ # https://github.com/huggingface/diffusers/pull/12693#issuecomment-3627519544
1339+ if is_torch_version ("<" , "2.9.0" ):
1340+ lse = lse .unsqueeze (- 1 )
12411341 if prev_out is not None :
12421342 out = prev_out - torch .nn .functional .sigmoid (lse - prev_lse ) * (prev_out - out )
12431343 lse = prev_lse - torch .nn .functional .logsigmoid (prev_lse - lse )
@@ -1298,7 +1398,7 @@ def backward(
12981398
12991399 grad_query , grad_key , grad_value = (x .to (grad_out .dtype ) for x in (grad_query , grad_key , grad_value ))
13001400
1301- return grad_query , grad_key , grad_value , None , None , None , None , None , None , None , None
1401+ return grad_query , grad_key , grad_value , None , None , None , None , None , None , None , None , None
13021402
13031403
13041404class TemplatedUlyssesAttention (torch .autograd .Function ):
@@ -1393,7 +1493,69 @@ def backward(
13931493 x .flatten (0 , 1 ).permute (1 , 2 , 0 , 3 ).contiguous () for x in (grad_query , grad_key , grad_value )
13941494 )
13951495
1396- return grad_query , grad_key , grad_value , None , None , None , None , None , None , None , None
1496+ return grad_query , grad_key , grad_value , None , None , None , None , None , None , None , None , None
1497+
1498+
def _templated_unified_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor],
    dropout_p: float,
    is_causal: bool,
    scale: Optional[float],
    enable_gqa: bool,
    return_lse: bool,
    forward_op,
    backward_op,
    _parallel_config: Optional["ParallelConfig"] = None,
    scatter_idx: int = 2,
    gather_idx: int = 1,
):
    """
    Unified Sequence Parallelism attention combining Ulysses and ring attention. See: https://arxiv.org/abs/2405.07719
    """
    ulysses_group = _parallel_config.context_parallel_config._ulysses_mesh.get_group()

    # Exchange sequence shards for head shards so every rank holds the full
    # sequence for its subset of heads before ring attention runs.
    query = SeqAllToAllDim.apply(ulysses_group, query, scatter_idx, gather_idx)
    key = SeqAllToAllDim.apply(ulysses_group, key, scatter_idx, gather_idx)
    value = SeqAllToAllDim.apply(ulysses_group, value, scatter_idx, gather_idx)

    ring_out = TemplatedRingAttention.apply(
        query,
        key,
        value,
        attn_mask,
        dropout_p,
        is_causal,
        scale,
        enable_gqa,
        return_lse,
        forward_op,
        backward_op,
        _parallel_config,
    )
    lse = None
    if return_lse:
        context_layer, lse, *_ = ring_out
    else:
        context_layer = ring_out

    # context_layer is of shape (B, S, H_LOCAL, D); reverse the exchange to get
    # back the (B, S_LOCAL, H, D) layout.
    output = SeqAllToAllDim.apply(ulysses_group, context_layer, gather_idx, scatter_idx)
    if not return_lse:
        return output

    # lse is of shape (B, S, H_LOCAL, 1)
    # Refer to:
    # https://github.com/huggingface/diffusers/pull/12693#issuecomment-3627519544
    if is_torch_version("<", "2.9.0"):
        lse = lse.unsqueeze(-1)  # (B, S, H_LOCAL, 1)
    lse = SeqAllToAllDim.apply(ulysses_group, lse, gather_idx, scatter_idx)
    lse = lse.squeeze(-1)
    return (output, lse)
13971559
13981560
13991561def _templated_context_parallel_attention (
@@ -1419,7 +1581,25 @@ def _templated_context_parallel_attention(
14191581 raise ValueError ("GQA is not yet supported for templated attention." )
14201582
14211583 # TODO: add support for unified attention with ring/ulysses degree both being > 1
1422- if _parallel_config .context_parallel_config .ring_degree > 1 :
1584+ if (
1585+ _parallel_config .context_parallel_config .ring_degree > 1
1586+ and _parallel_config .context_parallel_config .ulysses_degree > 1
1587+ ):
1588+ return _templated_unified_attention (
1589+ query ,
1590+ key ,
1591+ value ,
1592+ attn_mask ,
1593+ dropout_p ,
1594+ is_causal ,
1595+ scale ,
1596+ enable_gqa ,
1597+ return_lse ,
1598+ forward_op ,
1599+ backward_op ,
1600+ _parallel_config ,
1601+ )
1602+ elif _parallel_config .context_parallel_config .ring_degree > 1 :
14231603 return TemplatedRingAttention .apply (
14241604 query ,
14251605 key ,
0 commit comments