Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions src/diffusers/models/attention_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2200,6 +2200,12 @@ def forward(
query, key, value = (_all_to_all_single(x, group) for x in (query, key, value))
query, key, value = (x.flatten(0, 1).permute(1, 0, 2, 3).contiguous() for x in (query, key, value))

if attn_mask is not None and attn_mask.shape[-1] == S_KV_LOCAL:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor: the comment says "All-gather" but the semantic intent is closer to "all-gather the per-rank local masks and concatenate them so the mask covers the full (global) KV sequence after the Ulysses all-to-all on QKV." The current comment is fine but could be slightly more precise about why the layout matches — the all-to-all on QKV concatenates sequence chunks from each rank in rank order, and the all-gather + cat on the mask does the same.

# All-gather a local mask so its layout matches the QKV layout after all-to-all.
mask_list = [torch.empty_like(attn_mask) for _ in range(world_size)]
dist.all_gather(mask_list, attn_mask, group=group)
attn_mask = torch.cat(mask_list, dim=-1)
Comment thread
sayakpaul marked this conversation as resolved.

out = forward_op(
ctx,
query,
Expand Down Expand Up @@ -2399,6 +2405,8 @@ def forward(
ctx.backward_op = backward_op
ctx._parallel_config = _parallel_config

_, S_KV_LOCAL, _, _ = key.shape

metadata = ulysses_anything_metadata(query)
query_wait = all_to_all_single_any_qkv_async(query, group, **metadata)
key_wait = all_to_all_single_any_qkv_async(key, group, **metadata)
Expand All @@ -2408,6 +2416,19 @@ def forward(
key = key_wait() # type: torch.Tensor
value = value_wait() # type: torch.Tensor

if attn_mask is not None and attn_mask.shape[-1] == S_KV_LOCAL:
# All-gather a local mask to match the post-all-to-all global sequence.
# The "anything" path allows unequal local sizes, so we pad to the
# maximum across ranks before all-gathering, then trim back.
mask_local_sizes = gather_size_by_comm(attn_mask.shape[-1], group)
max_local = max(mask_local_sizes)
if attn_mask.shape[-1] < max_local:
attn_mask = F.pad(attn_mask, (0, max_local - attn_mask.shape[-1]))
mask_list = [torch.empty_like(attn_mask) for _ in range(dist.get_world_size(group=group))]
dist.all_gather(mask_list, attn_mask, group=group)
attn_mask = torch.cat(mask_list, dim=-1)
attn_mask = attn_mask[..., : sum(mask_local_sizes)]

out = forward_op(
ctx,
query,
Expand Down
34 changes: 18 additions & 16 deletions src/diffusers/models/transformers/transformer_qwenimage.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,12 +491,24 @@ def __call__(
hidden_states: torch.FloatTensor, # Image stream
encoder_hidden_states: torch.FloatTensor = None, # Text stream
encoder_hidden_states_mask: torch.FloatTensor = None,
attention_mask: torch.FloatTensor | None = None,
attention_mask: None = None,
Comment thread
zhtmike marked this conversation as resolved.
Outdated
image_rotary_emb: torch.Tensor | None = None,
) -> torch.FloatTensor:
if encoder_hidden_states is None:
raise ValueError("QwenDoubleStreamAttnProcessor2_0 requires encoder_hidden_states (text stream)")

if attention_mask is not None:
raise ValueError(
"QwenDoubleStreamAttnProcessor2_0 does not accept an external attention_mask. "
"Pass encoder_hidden_states_mask to let the processor build the joint mask."
)

if encoder_hidden_states_mask is not None:
seq_img = hidden_states.shape[1]
image_mask = torch.ones((hidden_states.shape[0], seq_img), dtype=torch.bool, device=hidden_states.device)
attention_mask = torch.cat([encoder_hidden_states_mask, image_mask], dim=1)
attention_mask = attention_mask[:, None, None, :]

seq_txt = encoder_hidden_states.shape[1]

# Compute QKV for image stream (sample projections)
Expand Down Expand Up @@ -770,6 +782,7 @@ class QwenImageTransformer2DModel(
},
"transformer_blocks.*": {
"modulate_index": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False),
"encoder_hidden_states_mask": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False),
},
"pos_embed": {
0: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True),
Expand Down Expand Up @@ -909,38 +922,27 @@ def forward(

image_rotary_emb = self.pos_embed(img_shapes, max_txt_seq_len=text_seq_len, device=hidden_states.device)

# Construct joint attention mask once to avoid reconstructing in every block
# This eliminates 60 GPU syncs during training while maintaining torch.compile compatibility
block_attention_kwargs = attention_kwargs.copy() if attention_kwargs is not None else {}
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cc @kashif here since this would revert the optimization as part of this PR #12702

I tried to go over #12702, but was not able to find much detail about this optimization. I would love to understand more about the cause of the sync and performance delta, because the pre-built joint mask does not shard correctly under CP

if encoder_hidden_states_mask is not None:
# Build joint mask: [text_mask, all_ones_for_image]
batch_size, image_seq_len = hidden_states.shape[:2]
image_mask = torch.ones((batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device)
joint_attention_mask = torch.cat([encoder_hidden_states_mask, image_mask], dim=1)
joint_attention_mask = joint_attention_mask[:, None, None, :]
block_attention_kwargs["attention_mask"] = joint_attention_mask

for index_block, block in enumerate(self.transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
encoder_hidden_states,
None, # Don't pass encoder_hidden_states_mask (using attention_mask instead)
encoder_hidden_states_mask,
temb,
image_rotary_emb,
block_attention_kwargs,
attention_kwargs,
modulate_index,
)

else:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
encoder_hidden_states_mask=None, # Don't pass (using attention_mask instead)
encoder_hidden_states_mask=encoder_hidden_states_mask,
temb=temb,
image_rotary_emb=image_rotary_emb,
joint_attention_kwargs=block_attention_kwargs,
joint_attention_kwargs=attention_kwargs,
modulate_index=modulate_index,
)

Expand Down
Loading