Skip to content
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
dd56429
feat: support Ulysses Anything Attention
DefTruth Jan 19, 2026
123f526
feat: support Ulysses Anything Attention
DefTruth Jan 19, 2026
af9af62
feat: support Ulysses Anything Attention
DefTruth Jan 19, 2026
40800e5
Merge branch 'main' into ulysses-anything
DefTruth Jan 19, 2026
0918547
Merge branch 'main' into ulysses-anything
DefTruth Jan 20, 2026
b4d3f07
feat: support Ulysses Anything Attention
DefTruth Jan 20, 2026
ece9f15
Merge branch 'main' into ulysses-anything
DefTruth Jan 20, 2026
7f9c412
Merge branch 'main' into ulysses-anything
DefTruth Jan 21, 2026
403c204
fix UAA broken while using joint attn
DefTruth Jan 21, 2026
9280e2b
update
DefTruth Jan 21, 2026
a5d2459
Merge branch 'main' into ulysses-anything
DefTruth Jan 22, 2026
de157c8
Merge branch 'main' into ulysses-anything
DefTruth Jan 23, 2026
4caa87e
post check
DefTruth Jan 23, 2026
140ece8
add docs
DefTruth Jan 26, 2026
7f2ecf4
Merge branch 'main' into ulysses-anything
DefTruth Jan 26, 2026
4cc56b9
add docs
DefTruth Jan 26, 2026
fd88b2d
Merge branch 'ulysses-anything' of https://github.com/xlite-dev/diffu…
DefTruth Jan 26, 2026
f1c32d9
Merge branch 'main' into ulysses-anything
DefTruth Jan 28, 2026
a74820b
remove lru cache
DefTruth Jan 28, 2026
d141801
Merge branch 'main' into ulysses-anything
DefTruth Jan 29, 2026
f8f2209
move codes
DefTruth Jan 29, 2026
9924f8b
Merge branch 'main' into ulysses-anything
DefTruth Jan 29, 2026
e1d83eb
update
DefTruth Jan 30, 2026
e426b64
Merge branch 'ulysses-anything' of https://github.com/xlite-dev/diffu…
DefTruth Jan 30, 2026
cb6c4fd
Merge branch 'main' into ulysses-anything
DefTruth Jan 30, 2026
ef919b4
Merge branch 'main' into ulysses-anything
DefTruth Feb 1, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions docs/source/en/training/distributed_inference.md
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,34 @@ We ran a benchmark with Ulysess, Ring, and Unified Attention with [this script](

From the above table, it's clear that Ulysses provides better throughput, but the number of devices it can use remains limited to the number of attention heads, a limitation that is solved by unified attention.


### Ulysses Anything Attention

The default Ulysses Attention mechanism requires that the sequence length of hidden states must be divisible by the number of devices. This imposes significant limitations on the practical application of Ulysses Attention. [Ulysses Anything Attention](https://github.com/huggingface/diffusers/pull/12996) is a variant of Ulysses Attention that supports arbitrary sequence lengths and arbitrary numbers of attention heads, thereby enhancing the versatility of Ulysses Attention in practical use.

[`ContextParallelConfig`] supports Ulysses Anything Attention by specifying both `ulysses_degree` and `ulysses_anything`. Please note that Ulysses Anything Attention is not currently supported by Unified Attention. Pass the [`ContextParallelConfig`] with both `ulysses_degree` set to bigger than 1 and `ulysses_anything=True` to [`~ModelMixin.enable_parallelism`].

```py
pipeline.transformer.enable_parallelism(config=ContextParallelConfig(ulysses_degree=2, ulysses_anything=True))
```

> [!TIP] To avoid multiple forced CUDA sync caused by H2D and D2H transfers, please add the **gloo** backend in `init_process_group`. This will significantly reduce communication latency.

We ran a benchmark for FLUX.1-dev with Ulysses, Ring, Unified Attention and Ulysses Anything Attention with [this script](https://github.com/huggingface/diffusers/pull/12996#issuecomment-3797695999) on a node of 4 L20 GPUs. The results are summarized as follows:

| CP Backend | Time / Iter (ms) | Steps / Sec | Peak Memory (GB) | Shape (HxW)|
|--------------------|------------------|-------------|------------------|------------|
| ulysses | 281.07 | 3.56 | 37.11 | 1024x1024 |
| ring | 351.34 | 2.85 | 37.01 | 1024x1024 |
| unified_balanced | 324.37 | 3.08 | 37.16 | 1024x1024 |
| ulysses_anything | 280.94 | 3.56 | 37.11 | 1024x1024 |
| ulysses | failed | failed | failed | 1008x1008 |
| ring | failed | failed | failed | 1008x1008 |
| unified_balanced | failed | failed | failed | 1008x1008 |
Comment on lines +367 to +369
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this from a failed eval? Can it be removed?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I keep the failed results here to demonstrate that Ulysses Anything can handle cases that the standard Ulysses, Ring and USP fail to process.

| ulysses_anything | 278.40 | 3.59 | 36.99 | 1008x1008 |

From the above table, it is clear that Ulysses Anything Attention offers better compatibility with arbitrary sequence lengths while maintaining the same performance as the standard Ulysses Attention.

### parallel_config

Pass `parallel_config` during model initialization to enable context parallelism.
Expand Down
14 changes: 13 additions & 1 deletion src/diffusers/hooks/context_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
ContextParallelModelPlan,
ContextParallelOutput,
)
from ..models._ulysses_anything_utils import PartitionAnythingSharder
from ..utils import get_logger
from ..utils.torch_utils import unwrap_module
from .hooks import HookRegistry, ModelHook
Expand Down Expand Up @@ -208,6 +209,10 @@ def _prepare_cp_input(self, x: torch.Tensor, cp_input: ContextParallelInput) ->
)
return x
else:
if self.parallel_config.ulysses_anything:
return PartitionAnythingSharder.shard_anything(
x, cp_input.split_dim, self.parallel_config._flattened_mesh
)
return EquipartitionSharder.shard(x, cp_input.split_dim, self.parallel_config._flattened_mesh)


Expand All @@ -233,7 +238,14 @@ def post_forward(self, module, output):
for i, cpm in enumerate(self.metadata):
if cpm is None:
continue
output[i] = EquipartitionSharder.unshard(output[i], cpm.gather_dim, self.parallel_config._flattened_mesh)
if self.parallel_config.ulysses_anything:
output[i] = PartitionAnythingSharder.unshard_anything(
output[i], cpm.gather_dim, self.parallel_config._flattened_mesh
)
else:
output[i] = EquipartitionSharder.unshard(
output[i], cpm.gather_dim, self.parallel_config._flattened_mesh
)

return output[0] if is_tensor else tuple(output)

Expand Down
8 changes: 8 additions & 0 deletions src/diffusers/models/_modeling_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ class ContextParallelConfig:
convert_to_fp32: bool = True
# TODO: support alltoall
rotate_method: Literal["allgather", "alltoall"] = "allgather"
# Whether to enable ulysses anything attention to support
# any sequence lengths and any head numbers.
ulysses_anything: bool = False
Comment thread
DefTruth marked this conversation as resolved.

_rank: int = None
_world_size: int = None
Expand Down Expand Up @@ -94,6 +97,11 @@ def __post_init__(self):
raise NotImplementedError(
f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}."
)
if self.ulysses_anything:
if self.ulysses_degree == 1:
raise ValueError("ulysses_degree must be greater than 1 for ulysses_anything to be enabled.")
if self.ring_degree > 1:
raise ValueError("ulysses_anything cannot be enabled when ring_degree > 1.")

@property
def mesh_shape(self) -> Tuple[int, int]:
Expand Down
286 changes: 286 additions & 0 deletions src/diffusers/models/_ulysses_anything_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,286 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
Comment thread
DefTruth marked this conversation as resolved.
Outdated
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Adapted from: https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/parallelism/attention/_templated_ulysses.py
import copy
import functools
from typing import Callable, List, Tuple

import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as fc
import torch.nn.functional as F

from ..utils.torch_utils import maybe_allow_in_graph


# Helper functions for shape gathering
def _get_rank_world_size(group: dist.ProcessGroup) -> Tuple[int, int]:
world_size = dist.get_world_size(group=group)
rank = dist.get_rank(group=group)
return rank, world_size


def _gather_size_by_comm(size: int, group: dist.ProcessGroup) -> List[int]:
r"""Gather the local size from all ranks.
size: int, local size return: List[int], list of size from all ranks
"""
# NOTE(Serving/CP Safety):
# Do NOT cache this collective result.
#
# In "Ulysses Anything" mode, `size` (e.g. per-rank local seq_len / S_LOCAL)
# may legitimately differ across ranks. If we cache based on the *local* `size`,
# different ranks can have different cache hit/miss patterns across time.
#
# That can lead to a catastrophic distributed hang:
# - some ranks hit cache and *skip* dist.all_gather()
# - other ranks miss cache and *enter* dist.all_gather()
# This mismatched collective participation will stall the process group and
# eventually trigger NCCL watchdog timeouts (often surfacing later as ALLTOALL
# timeouts in Ulysses attention).
world_size = dist.get_world_size(group=group)
# HACK: Use Gloo backend for all_gather to avoid H2D and D2H overhead
comm_backends = str(dist.get_backend(group=group))
# NOTE: e.g., dist.init_process_group(backend="cpu:gloo,cuda:nccl")
gather_device = "cpu" if "cpu" in comm_backends else torch.accelerator.current_accelerator()
gathered_sizes = [torch.empty((1,), device=gather_device, dtype=torch.int64) for _ in range(world_size)]
dist.all_gather(
gathered_sizes,
torch.tensor([size], device=gather_device, dtype=torch.int64),
group=group,
)

gathered_sizes = [s[0].item() for s in gathered_sizes]
# NOTE: DON'T use tolist here due to graph break - Explanation:
# Backend compiler `inductor` failed with aten._local_scalar_dense.default
return gathered_sizes


# Helper functions to pad/unpad head dimension for QKV and O projections
def _maybe_pad_qkv_head(x: torch.Tensor, H: int, group: dist.ProcessGroup) -> Tuple[torch.Tensor, int]:
r"""Maybe pad the head dimension to be divisible by world_size.
x: torch.Tensor, shape (B, S_LOCAL, H, D) H: int, original global head num return: Tuple[torch.Tensor, int], padded
tensor (B, S_LOCAL, H + H_PAD, D) and H_PAD
"""
_, world_size = _get_rank_world_size(group)
H_PAD = 0
if H % world_size != 0:
H_PAD = world_size - (H % world_size)
NEW_H_LOCAL = (H + H_PAD) // world_size
# e.g., Allow: H=30, world_size=8 -> NEW_H_LOCAL=4, H_PAD=2.
# NOT ALLOW: H=30, world_size=16 -> NEW_H_LOCAL=2, H_PAD=14.
assert H_PAD < NEW_H_LOCAL, f"Padding head num {H_PAD} should be less than new local head num {NEW_H_LOCAL}"
x = F.pad(x, (0, 0, 0, H_PAD)).contiguous()
return x, H_PAD


def _maybe_unpad_qkv_head(x: torch.Tensor, H_PAD: int, group: dist.ProcessGroup) -> torch.Tensor:
r"""Maybe unpad the head dimension.
x: torch.Tensor, shape (B, S_GLOBAL, H_LOCAL + H_PAD, D) H_PAD: int, head padding num return: torch.Tensor,
unpadded tensor (B, S_GLOBAL, H_LOCAL, D)
"""
rank, world_size = _get_rank_world_size(group)
# Only the last rank may have padding
if H_PAD > 0 and rank == world_size - 1:
x = x[:, :, :-H_PAD, :]
return x.contiguous()


def _maybe_pad_o_head(x: torch.Tensor, H: int, group: dist.ProcessGroup) -> Tuple[torch.Tensor, int]:
r"""Maybe pad the head dimension to be divisible by world_size.
x: torch.Tensor, shape (B, S_GLOBAL, H_LOCAL, D) H: int, original global head num return: Tuple[torch.Tensor, int],
padded tensor (B, S_GLOBAL, H_LOCAL + H_PAD, D) and H_PAD
"""
if H is None:
return x, 0

rank, world_size = _get_rank_world_size(group)
H_PAD = 0
# Only the last rank may need padding
if H % world_size != 0:
# We need to broadcast H_PAD to all ranks to keep consistency
# in unpadding step later for all ranks.
H_PAD = world_size - (H % world_size)
NEW_H_LOCAL = (H + H_PAD) // world_size
assert H_PAD < NEW_H_LOCAL, f"Padding head num {H_PAD} should be less than new local head num {NEW_H_LOCAL}"
if rank == world_size - 1:
x = F.pad(x, (0, 0, 0, H_PAD)).contiguous()
return x, H_PAD


def _maybe_unpad_o_head(x: torch.Tensor, H_PAD: int, group: dist.ProcessGroup) -> torch.Tensor:
r"""Maybe unpad the head dimension.
x: torch.Tensor, shape (B, S_LOCAL, H_GLOBAL + H_PAD, D) H_PAD: int, head padding num return: torch.Tensor,
unpadded tensor (B, S_LOCAL, H_GLOBAL, D)
"""
if H_PAD > 0:
x = x[:, :, :-H_PAD, :]
return x.contiguous()


# Helper functions to for all-to-all communication with Ulysses Anything Attention
def _wait_tensor(tensor) -> torch.Tensor:
if isinstance(tensor, fc.AsyncCollectiveTensor):
tensor = tensor.wait()

return tensor


def ulysses_anything_metadata(query: torch.Tensor, **kwargs) -> dict:
# query: (B, S_LOCAL, H_GLOBAL, D)
assert len(query.shape) == 4, "Query tensor must be 4-dimensional of shape (B, S_LOCAL, H_GLOBAL, D)"
extra_kwargs = {}
extra_kwargs["NUM_QO_HEAD"] = query.shape[2]
extra_kwargs["Q_S_LOCAL"] = query.shape[1]
# Add other kwargs if needed in future
return extra_kwargs


@maybe_allow_in_graph
def all_to_all_single_any_qkv_async(
x: torch.Tensor, group: dist.ProcessGroup, **kwargs
) -> Callable[..., torch.Tensor]:
r"""
x: torch.Tensor, shape (B, S_LOCAL, H, D) return: Callable that returns (B, S_GLOBAL, H_LOCAL, D)
"""
_, world_size = _get_rank_world_size(group)
B, S_LOCAL, H, D = x.shape
x, H_PAD = _maybe_pad_qkv_head(x, H, group)
H_LOCAL = (H + H_PAD) // world_size
# (world_size, S_LOCAL, B, H_LOCAL, D)
x = x.reshape(B, S_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()

input_split_sizes = [S_LOCAL] * world_size
# S_LOCAL maybe not equal for all ranks in dynamic shape case,
# since we don't know the actual shape before this timing, thus,
# we have to use all gather to collect the S_LOCAL first.
output_split_sizes = _gather_size_by_comm(S_LOCAL, group)
x = x.flatten(0, 1) # (world_size * S_LOCAL, B, H_LOCAL, D)
x = fc.all_to_all_single(x, output_split_sizes, input_split_sizes, group)

def wait() -> torch.Tensor:
nonlocal x, H_PAD
x = _wait_tensor(x) # (S_GLOBAL, B, H_LOCAL, D)
# (S_GLOBAL, B, H_LOCAL, D)
# -> (B, S_GLOBAL, H_LOCAL, D)
x = x.permute(1, 0, 2, 3).contiguous()
x = _maybe_unpad_qkv_head(x, H_PAD, group)
return x

return wait


@maybe_allow_in_graph
def all_to_all_single_any_o_async(x: torch.Tensor, group: dist.ProcessGroup, **kwargs) -> Callable[..., torch.Tensor]:
r"""
x: torch.Tensor, shape (B, S_GLOBAL, H_LOCAL, D) return: Callable that returns (B, S_LOCAL, H_GLOBAL, D)
"""
# Assume H is provided in kwargs, since we can't infer H from x's shape.
# The padding logic needs H to determine if padding is necessary.
H = kwargs.get("NUM_QO_HEAD", None)
rank, world_size = _get_rank_world_size(group)
x, H_PAD = _maybe_pad_o_head(x, H, group)
shape = x.shape # (B, S_GLOBAL, H_LOCAL, D)
(B, S_GLOBAL, H_LOCAL, D) = shape

# input_split: e.g, S_GLOBAL=9 input splits across ranks [[5,4], [5,4],..]
# output_split: e.g, S_GLOBAL=9 output splits across ranks [[5,5], [4,4],..]

# WARN: In some cases, e.g, joint attn in Qwen-Image, the S_LOCAL can not infer
# from tensor split due to: if c = torch.cat((a, b)), world_size=4, then,
# c.tensor_split(4)[0].shape[1] may != to (a.tensor_split(4)[0].shape[1] +
# b.tensor_split(4)[0].shape[1])

S_LOCAL = kwargs.get("Q_S_LOCAL")
input_split_sizes = _gather_size_by_comm(S_LOCAL, group)
x = x.permute(1, 0, 2, 3).contiguous() # (S_GLOBAL, B, H_LOCAL, D)
output_split_sizes = [S_LOCAL] * world_size
x = fc.all_to_all_single(x, output_split_sizes, input_split_sizes, group)

def wait() -> torch.Tensor:
nonlocal x, H_PAD
x = _wait_tensor(x) # (S_GLOBAL, B, H_LOCAL, D)
x = x.reshape(world_size, S_LOCAL, B, H_LOCAL, D)
x = x.permute(2, 1, 0, 3, 4).contiguous()
x = x.reshape(B, S_LOCAL, world_size * H_LOCAL, D)
x = _maybe_unpad_o_head(x, H_PAD, group)
return x

return wait


@functools.lru_cache(maxsize=64)
def _fill_gather_shapes(shape: Tuple[int], gather_dims: Tuple[int], dim: int, world_size: int) -> List[List[int]]:
gather_shapes = []
for i in range(world_size):
rank_shape = list(copy.deepcopy(shape))
rank_shape[dim] = gather_dims[i]
gather_shapes.append(rank_shape)
return gather_shapes


@maybe_allow_in_graph
def _all_gather_anything(tensor: torch.Tensor, dim: int, group: dist.device_mesh.DeviceMesh) -> torch.Tensor:
_, world_size = _get_rank_world_size(group)
tensor = tensor.contiguous()
shape = tensor.shape
rank_dim = shape[dim]
gather_dims = _gather_size_by_comm(rank_dim, group)

gather_shapes = _fill_gather_shapes(tuple(shape), tuple(gather_dims), dim, world_size)

gathered_tensors = [torch.empty(shape, device=tensor.device, dtype=tensor.dtype) for shape in gather_shapes]

dist.all_gather(gathered_tensors, tensor, group=group)
gathered_tensor = torch.cat(gathered_tensors, dim=dim)
return gathered_tensor


class AllGatherAnythingFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, tensor: torch.Tensor, dim: int, group: dist.device_mesh.DeviceMesh):
ctx.dim = dim
ctx.group = group
ctx.world_size = dist.get_world_size(group)
ctx.rank = dist.get_rank(group)
gathered_tensor = _all_gather_anything(tensor, dim, group)
return gathered_tensor

@staticmethod
def backward(ctx, grad_output):
# NOTE: We use `tensor_split` instead of chunk, because the `chunk`
# function may return fewer than the specified number of chunks!
grad_splits = torch.tensor_split(grad_output, ctx.world_size, dim=ctx.dim)
return grad_splits[ctx.rank], None, None


class PartitionAnythingSharder:
@classmethod
def shard_anything(
cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh
) -> torch.Tensor:
assert tensor.size()[dim] >= mesh.size(), (
f"Cannot shard tensor of size {tensor.size()} along dim {dim} across mesh of size {mesh.size()}."
)
# NOTE: We use `tensor_split` instead of chunk, because the `chunk`
# function may return fewer than the specified number of chunks!
return tensor.tensor_split(mesh.size(), dim=dim)[dist.get_rank(mesh.get_group())]

@classmethod
def unshard_anything(
cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh
) -> torch.Tensor:
tensor = tensor.contiguous()
tensor = AllGatherAnythingFunction.apply(tensor, dim, mesh.get_group())
return tensor
Loading