34 changes: 34 additions & 0 deletions docs/source/en/training/distributed_inference.md
@@ -371,6 +371,40 @@ We ran a benchmark for FLUX.1-dev with Ulysses, Ring, Unified Attention and Ulys

From the above table, it is clear that Ulysses Anything Attention offers better compatibility with arbitrary sequence lengths while maintaining the same performance as the standard Ulysses Attention.


### Ring Anything Attention

The default Ring Attention requires the sequence length of hidden states to be evenly divisible across the ring degree. Ring Anything Attention is a variant of Ring Attention that supports arbitrary (non-evenly divisible) sequence lengths. It pads each rank's local KV to the global maximum sequence length, all-gathers the padded KV buffer, and slices back to each rank's true length before running attention.
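The pad/all-gather/slice round trip can be sketched on a single process. This is a minimal illustration, not diffusers API: `ring_anything_gather` and the simulated per-rank KV tensors (shape `(heads, seq, dim)`) are hypothetical stand-ins for the distributed implementation.

```python
import torch

def ring_anything_gather(local_kvs):
    # Simulate the per-rank pad -> all-gather -> slice round trip on one process.
    # local_kvs: one KV tensor per rank, with differing sequence lengths on dim 1.
    seq_lens = [kv.shape[1] for kv in local_kvs]  # stands in for the size exchange
    s_max = max(seq_lens)
    padded = []
    for kv in local_kvs:  # pad each rank's KV to the global max sequence length
        pad = torch.zeros(kv.shape[0], s_max - kv.shape[1], kv.shape[2])
        padded.append(torch.cat([kv, pad], dim=1))
    gathered = torch.stack(padded)  # stands in for all_gather of the padded buffers
    # Each rank slices every peer's buffer back to its true length before attention.
    return [gathered[r, :, :seq_lens[r]] for r in range(len(local_kvs))]

# Ranks hold 5, 3, and 4 tokens -- lengths not evenly divisible across the ring.
chunks = [torch.randn(2, s, 8) for s in (5, 3, 4)]
restored = ring_anything_gather(chunks)
assert all(torch.equal(a, b) for a, b in zip(chunks, restored))
```

The slice step is what lets attention run on each rank's true sequence length, so the zero padding never contaminates the softmax.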

Ring Anything Attention is not supported by Unified Attention. To enable it, set `ring_degree` to a value greater than 1 and `ring_anything=True` in [`ContextParallelConfig`], then pass the config to [`~ModelMixin.enable_parallelism`].


```py
pipeline.transformer.enable_parallelism(config=ContextParallelConfig(ring_degree=2, ring_anything=True))
```

> [!TIP]
> To avoid multiple forced CUDA syncs caused by H2D and D2H transfers, add the **gloo** backend in `init_process_group`. This significantly reduces communication latency.
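A minimal launch sketch of this setup. In a real run `torchrun` sets the rendezvous variables and every rank executes the script; the single-process fallback below exists only so the snippet runs standalone, and the port number is arbitrary.

```python
import os

import torch
import torch.distributed as dist

# torchrun normally provides MASTER_ADDR / MASTER_PORT / RANK / WORLD_SIZE.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")

# Register gloo for CPU tensors alongside NCCL for CUDA tensors, so small
# CPU-side collectives (e.g. sequence-length exchange) avoid forced CUDA syncs.
backend = "cpu:gloo,cuda:nccl" if torch.cuda.is_available() else "gloo"
dist.init_process_group(
    backend=backend,
    rank=int(os.environ.get("RANK", 0)),
    world_size=int(os.environ.get("WORLD_SIZE", 1)),
)

initialized = dist.is_initialized()
dist.destroy_process_group()
```

The `"cpu:gloo,cuda:nccl"` device-to-backend mapping is standard PyTorch; with it, collectives on CPU tensors are routed to gloo automatically.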

> [!NOTE]
> Ring Anything Attention currently supports inference only; the backward pass is not implemented. Attention masks are also unsupported, so `attn_mask` must be `None`.


See the benchmarks below, run for FLUX.1-dev with Ulysses, Ring, Unified Attention, Ulysses Anything Attention, and Ring Anything Attention on a node of 4 RTX 4090 (48GB) GPUs.


| CP Backend | Time / Iter (ms) | Steps / Sec | Peak Memory (GB) | Shape (HxW) |
|--------------------|------------------|-------------|------------------|------------|
| ulysses | 259.07 | 3.86 | 33.83 | 1024x1024 |
| ring | 338.98 | 2.95 | 33.83 | 1024x1024 |
| unified_balanced | 321.54 | 3.11 | 33.83 | 1024x1024 |
| ulysses_anything | 259.07 | 3.86 | 33.83 | 1024x1024 |
| ring_anything | 340.14 | 2.94 | 33.83 | 1024x1024 |
| ulysses | failed | failed | failed | 1008x1008 |
| ring | failed | failed | failed | 1008x1008 |
| unified_balanced | failed | failed | failed | 1008x1008 |
| ulysses_anything | 253.16 | 3.95 | 33.75 | 1008x1008 |
| ring_anything | 335.57 | 2.98 | 33.75 | 1008x1008 |

From the above table, Ring Anything Attention offers compatibility with arbitrary sequence lengths while maintaining performance comparable to the standard Ring Attention.
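The 1008x1008 failures follow from token-count divisibility. Assuming FLUX tokenizes images at one latent token per 16x16 pixel tile (8x VAE downsampling plus 2x2 patchification) and a parallel degree of 4 matching the benchmark node — both assumptions about this setup — the arithmetic works out as:

```python
def flux_token_count(height, width, px_per_token=16):
    # One latent token per 16x16 pixel tile (8x VAE downsample, 2x2 patchify).
    return (height // px_per_token) * (width // px_per_token)

degree = 4  # one node of 4 GPUs, as in the benchmark above
for side in (1024, 1008):
    tokens = flux_token_count(side, side)
    print(f"{side}x{side}: {tokens} tokens, divisible by {degree}: {tokens % degree == 0}")
# 1024x1024 -> 4096 tokens (divisible); 1008x1008 -> 3969 tokens (not divisible)
```

The standard backends require even divisibility and fail on 3969 tokens, while the Anything variants shard the remainder across ranks.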

### parallel_config

Pass `parallel_config` during model initialization to enable context parallelism.
4 changes: 2 additions & 2 deletions src/diffusers/hooks/context_parallel.py
@@ -210,7 +210,7 @@ def _prepare_cp_input(self, x: torch.Tensor, cp_input: ContextParallelInput) ->
)
return x
else:
if self.parallel_config.ulysses_anything:
if self.parallel_config.ulysses_anything or self.parallel_config.ring_anything:
return PartitionAnythingSharder.shard_anything(
x, cp_input.split_dim, self.parallel_config._flattened_mesh
)
@@ -239,7 +239,7 @@ def post_forward(self, module, output):
for i, cpm in enumerate(self.metadata):
if cpm is None:
continue
if self.parallel_config.ulysses_anything:
if self.parallel_config.ulysses_anything or self.parallel_config.ring_anything:
output[i] = PartitionAnythingSharder.unshard_anything(
output[i], cpm.gather_dim, self.parallel_config._flattened_mesh
)
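The shard/unshard pair used in this hook can be illustrated with `torch.tensor_split`, which produces near-even chunks when the length doesn't divide evenly. This is a sketch mirroring the names `shard_anything`/`unshard_anything`; the actual `PartitionAnythingSharder` split policy is an assumption here.

```python
import torch

def shard_anything(x, dim, world_size, rank):
    # Near-even split along `dim`; early ranks absorb the remainder (assumed policy).
    return torch.tensor_split(x, world_size, dim=dim)[rank]

def unshard_anything(shards, dim):
    # Inverse of shard_anything: concatenate the per-rank chunks back together.
    return torch.cat(shards, dim=dim)

x = torch.arange(10).reshape(1, 10)  # seq len 10, world size 4 -> chunks of 3, 3, 2, 2
shards = [shard_anything(x, 1, 4, r) for r in range(4)]
assert [s.shape[1] for s in shards] == [3, 3, 2, 2]
assert torch.equal(unshard_anything(shards, 1), x)
```

Because the chunk sizes differ across ranks, `post_forward` must use the matching uneven unshard rather than a fixed-size all-gather.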
12 changes: 12 additions & 0 deletions src/diffusers/models/_modeling_parallel.py
@@ -64,6 +64,9 @@ class ContextParallelConfig:
Whether to enable "Ulysses Anything" mode, which supports arbitrary sequence lengths and head counts that
are not evenly divisible by `ulysses_degree`. When enabled, `ulysses_degree` must be greater than 1 and
`ring_degree` must be 1.
ring_anything (`bool`, *optional*, defaults to `False`):
Whether to enable "Ring Anything" mode, which supports arbitrary sequence lengths. When enabled, `ring_degree`
must be greater than 1 and `ulysses_degree` must be 1.
mesh (`torch.distributed.device_mesh.DeviceMesh`, *optional*):
A custom device mesh to use for context parallelism. If provided, this mesh will be used instead of
creating a new one. This is useful when combining context parallelism with other parallelism strategies
@@ -82,6 +85,8 @@ class ContextParallelConfig:
# Whether to enable Ulysses Anything attention to support
# arbitrary sequence lengths and head counts.
ulysses_anything: bool = False
# Whether to enable Ring Anything attention to support arbitrary sequence lengths.
ring_anything: bool = False

_rank: int = None
_world_size: int = None
@@ -114,6 +119,13 @@ def __post_init__(self):
raise ValueError("ulysses_degree must be greater than 1 for ulysses_anything to be enabled.")
if self.ring_degree > 1:
raise ValueError("ulysses_anything cannot be enabled when ring_degree > 1.")
if self.ring_anything:
if self.ring_degree == 1:
raise ValueError("ring_degree must be greater than 1 for ring_anything to be enabled.")
if self.ulysses_degree > 1:
raise ValueError("ring_anything cannot be enabled when ulysses_degree > 1.")
if self.ulysses_anything and self.ring_anything:
raise ValueError("ulysses_anything and ring_anything cannot both be enabled.")

@property
def mesh_shape(self) -> tuple[int, int]:
157 changes: 143 additions & 14 deletions src/diffusers/models/attention_dispatch.py
@@ -2076,6 +2076,119 @@ def backward(
return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None, None


class TemplatedRingAnythingAttention(torch.autograd.Function):
@staticmethod
def forward(
ctx: torch.autograd.function.FunctionCtx,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: torch.Tensor | None,
dropout_p: float,
is_causal: bool,
scale: float | None,
enable_gqa: bool,
return_lse: bool,
forward_op,
backward_op,
_parallel_config: "ParallelConfig" | None = None,
):
# Ring attention for arbitrary sequence lengths.
if attn_mask is not None:
**Member:** Seems like a pretty big limitation no? This would make it incompatible with models like QwenImage, right?

**Contributor Author:** Yeah, that's fair. For this PR I'm keeping `ring_anything` scoped to the `attn_mask=None` path, which covers the FLUX/Wan cases I tested. QwenImage masks should be supportable, but I'd prefer to add that in a follow-up with proper validation.
raise ValueError(
"TemplatedRingAnythingAttention does not support non-None attn_mask: "
"non-uniform sequence lengths across ranks make cross-rank mask slicing ambiguous."
)
ring_mesh = _parallel_config.context_parallel_config._ring_mesh
group = ring_mesh.get_group()
rank = _parallel_config.context_parallel_config._ring_local_rank
world_size = _parallel_config.context_parallel_config.ring_degree
next_rank = (rank + 1) % world_size
prev_out = prev_lse = None

ctx.forward_op = forward_op
ctx.backward_op = backward_op
ctx.q_shape = query.shape
ctx.kv_shape = key.shape
ctx._parallel_config = _parallel_config

kv_seq_len = key.shape[1] # local S_KV (may differ across ranks)
all_kv_seq_lens = gather_size_by_comm(kv_seq_len, group)
s_max = max(all_kv_seq_lens)

# Padding is applied on the sequence dimension (dim=1) at the end.
def pad_to_s_max(t: torch.Tensor) -> torch.Tensor:
pad_len = s_max - t.shape[1]
if pad_len == 0:
return t
pad_shape = list(t.shape)
pad_shape[1] = pad_len
return torch.cat([t, torch.zeros(pad_shape, dtype=t.dtype, device=t.device)], dim=1)

key_padded = pad_to_s_max(key)
value_padded = pad_to_s_max(value)

kv_buffer = torch.cat([key_padded.flatten(), value_padded.flatten()]).contiguous()
kv_buffer = funcol.all_gather_tensor(kv_buffer, gather_dim=0, group=group)
kv_buffer = kv_buffer.chunk(world_size)

# numel per-rank in the padded layout
kv_padded_numel = key_padded.numel()

for i in range(world_size):
if i > 0:
true_seq_len = all_kv_seq_lens[next_rank]
kv = kv_buffer[next_rank]
# Reshape to padded shape, then slice to true sequence length
key = kv[:kv_padded_numel].reshape_as(key_padded)[:, :true_seq_len]
value = kv[kv_padded_numel:].reshape_as(value_padded)[:, :true_seq_len]
next_rank = (next_rank + 1) % world_size
else:
# i == 0: use local (unpadded) key/value
key = key_padded[:, :kv_seq_len]
value = value_padded[:, :kv_seq_len]

out, lse = forward_op(
ctx,
query,
key,
value,
attn_mask,
dropout_p,
is_causal,
scale,
enable_gqa,
True,
**Member** (on lines +2141 to +2164): @claude can we use `torch.where` here for a better conditional flow graph?

**Contributor Author:** I tried removing the branch, but it regressed on PCIe since the local chunk has to be read back from the all-gather buffer. Keeping the local fast path for now, but happy to revisit if there's a cleaner way to express it.
_save_ctx=i == 0,
_parallel_config=_parallel_config,
)

if _parallel_config.context_parallel_config.convert_to_fp32:
out = out.to(torch.float32)
lse = lse.to(torch.float32)

if is_torch_version("<", "2.9.0"):
lse = lse.unsqueeze(-1)
if prev_out is not None:
out = prev_out - torch.nn.functional.sigmoid(lse - prev_lse) * (prev_out - out)
lse = prev_lse - torch.nn.functional.logsigmoid(prev_lse - lse)
prev_out = out
prev_lse = lse

out = out.to(query.dtype)
lse = lse.squeeze(-1)

return (out, lse) if return_lse else out

@staticmethod
def backward(
ctx: torch.autograd.function.FunctionCtx,
grad_out: torch.Tensor,
*args,
):
raise NotImplementedError("Backward pass for Ring Anything Attention in diffusers is not implemented yet.")
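The sigmoid/logsigmoid update in the forward loop above is the standard online-softmax merge of per-chunk attention outputs and their log-sum-exp values. A small sketch (with a hypothetical `partial_attn` helper) checks that merging two uneven KV chunks reproduces full attention:

```python
import torch

def partial_attn(q, k, v):
    # Attention over one KV chunk, returning the chunk-normalized output and
    # its log-sum-exp (keepdim so it broadcasts against the output).
    scores = q @ k.transpose(-1, -2) / q.shape[-1] ** 0.5
    lse = torch.logsumexp(scores, dim=-1, keepdim=True)
    out = torch.softmax(scores, dim=-1) @ v
    return out, lse

torch.manual_seed(0)
q = torch.randn(2, 4, 8)
k = torch.randn(2, 6, 8)
v = torch.randn(2, 6, 8)

# Split KV into two uneven "ring" chunks (lengths 4 and 2).
out1, lse1 = partial_attn(q, k[:, :4], v[:, :4])
out2, lse2 = partial_attn(q, k[:, 4:], v[:, 4:])

# Same sigmoid/logsigmoid update as the forward loop, with (out1, lse1) as the
# running state and (out2, lse2) as the incoming chunk.
out = out1 - torch.sigmoid(lse2 - lse1) * (out1 - out2)
lse = lse1 - torch.nn.functional.logsigmoid(lse1 - lse2)

ref, ref_lse = partial_attn(q, k, v)
assert torch.allclose(out, ref, atol=1e-5)
assert torch.allclose(lse, ref_lse, atol=1e-5)
```

Since each rank's chunk is sliced to its true length before attention, the merge never sees padded keys, which is why the update needs no mask-aware correction.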


class TemplatedUlyssesAnythingAttention(torch.autograd.Function):
@staticmethod
def forward(
@@ -2254,20 +2367,36 @@ def _templated_context_parallel_attention(
_parallel_config,
)
elif _parallel_config.context_parallel_config.ring_degree > 1:
return TemplatedRingAttention.apply(
query,
key,
value,
attn_mask,
dropout_p,
is_causal,
scale,
enable_gqa,
return_lse,
forward_op,
backward_op,
_parallel_config,
)
if _parallel_config.context_parallel_config.ring_anything:
return TemplatedRingAnythingAttention.apply(
query,
key,
value,
attn_mask,
dropout_p,
is_causal,
scale,
enable_gqa,
return_lse,
forward_op,
backward_op,
_parallel_config,
)
else:
return TemplatedRingAttention.apply(
query,
key,
value,
attn_mask,
dropout_p,
is_causal,
scale,
enable_gqa,
return_lse,
forward_op,
backward_op,
_parallel_config,
)
elif _parallel_config.context_parallel_config.ulysses_degree > 1:
if _parallel_config.context_parallel_config.ulysses_anything:
# For arbitrary sequence length and head count support