Skip to content

Fix the QwenImage Attention mask under Ulysses SP#13756

Open
zhtmike wants to merge 7 commits into
huggingface:mainfrom
zhtmike:fix_qwen_mask
Open

Fix the QwenImage Attention mask under Ulysses SP#13756
zhtmike wants to merge 7 commits into
huggingface:mainfrom
zhtmike:fix_qwen_mask

Conversation

@zhtmike
Copy link
Copy Markdown
Contributor

@zhtmike zhtmike commented May 15, 2026

What does this PR do?

This fixes the issue #13696 . The test should be passed after this PR.

This the problem I found: The mask does not have a one-to-one correspondence with the content.

For QwenImage Pipeline, use the following example

Position:  0  1  2  3  4  5  6  7 | 8  9  10 11
Content:  T0 T1 T2 T3 T4 T5 T6 T7 | I0 I1 I2 I3
Mask:      1  1  0  0  0  0  0  0 | 1  1  1  1

After CP shard (assume 2 ranks)

Rank 0: text=[T0 T1 T2 T3],  image=[I0 I1]  → joint=[T0 T1 T2 T3 I0 I1]
Rank 1: text=[T4 T5 T6 T7],  image=[I2 I3]  → joint=[T4 T5 T6 T7 I2 I3]

After All-to-all

Position:  0  1  2  3 | 4  5  | 6  7  8  9  | 10 11
Content:  T0 T1 T2 T3 | I0 I1 | T4 T5 T6 T7 | I2 I3
           ← rank 0 →          ← rank 1 →

But the mask is not handled correctly

Position:  0  1  2  3  4  5  6  7 | 8  9  10 11
Mask:      1  1  0  0  0  0  0  0 | 1  1  1  1

This PR makes mask correctly assigned

Position:  0  1  2  3 | 4  5 | 6  7  8  9 | 10 11
Mask:      1  1  0  0 | 1  1 | 0  0  0  0 | 1  1

Fixes # (issue)

Before submitting

Who can review?

@sayakpaul

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@github-actions github-actions Bot added models size/S PR with diff < 50 LOC labels May 15, 2026
@zhtmike
Copy link
Copy Markdown
Contributor Author

zhtmike commented May 15, 2026

@sayakpaul, I can make this fix more generic by addressing TemplatedUlyssesAttention directly, which should help prevent similar errors in the future.

Here’s my suggestion:

  • Maintain separate local masks for image and text.
  • During all-to-all operations, ensure the mask and the corresponding content always have a one-to-one relationship.
    This way, we avoid numerical inconsistencies caused by mask-related issues.

However, this approach requires handling masks in attention layer locally (and with extra communication cost) and may revert some of the performance improvements introduced in #12702.

So do you have any suggestions?

@sayakpaul
Copy link
Copy Markdown
Member

Could you explain a bit why the actual mask is

Position:  0  1  2  3 | 4  5 | 6  7  8  9 | 10 11
Mask:      1  1  0  0 | 1  1 | 0  0  0  0 | 1  1

@zhtmike
Copy link
Copy Markdown
Contributor Author

zhtmike commented May 15, 2026

Could you explain a bit why the actual mask is

Position:  0  1  2  3 | 4  5 | 6  7  8  9 | 10 11
Mask:      1  1  0  0 | 1  1 | 0  0  0  0 | 1  1

since the original input is (T: token for text; I: Token for image)

Position:  0  1  2  3  4  5  6  7 | 8  9  10 11
Content:  T0 T1 T2 T3 T4 T5 T6 T7 | I0 I1 I2 I3
Mask:      1  1  0  0  0  0  0  0 | 1  1  1  1

after all-to-all

Position:  0  1  2  3 | 4  5  | 6  7  8  9  | 10 11
Content:  T0 T1 T2 T3 | I0 I1 | T4 T5 T6 T7 | I2 I3
           ← rank 0 →          ← rank 1 →

so the mask should be in the correct position matching the content, avoid the wrong attention of KV

Position:  0  1  2  3 | 4  5  | 6  7  8  9  | 10 11
Content:  T0 T1 T2 T3 | I0 I1 | T4 T5 T6 T7 | I2 I3
Mask:      1  1  0  0 | 1  1  | 0  0  0  0  | 1  1

@sayakpaul
Copy link
Copy Markdown
Member

I need @DN6's thoughts on this one. Dhruv, should we tackle it from the forward pass of the model or delegate to the Ulysses utilities we have? Personally, I would prefer doing it from the Ulysses utilities as it would help to keep the forward clean.

@sayakpaul
Copy link
Copy Markdown
Member

@claude-2-serge could you do a review?

Copy link
Copy Markdown
Contributor

@github-actions github-actions Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🤗 Serge says:

The fix correctly identifies that under Ulysses sequence parallelism, the attention mask must be interleaved to match the post-all-to-all token layout. The approach of chunking text and image masks separately and interleaving them is sound.

Potential Issue

  • _parallel_config not declared on the transformer model class: getattr(self, "_parallel_config", None) at line 919 references self which is the QwenImageTransformer2DModel instance. However, _parallel_config = None is only declared on QwenDoubleStreamAttnProcessor2_0 (line 480), not on the transformer model class itself. Other transformer models (e.g., transformer_flux.py) declare _parallel_config on both the processor and the model class. If the parallel setup code sets _parallel_config on the model via setattr, this works — but it would be more robust and consistent to also declare _parallel_config = None as a class attribute on QwenImageTransformer2DModel.

Style

  • Trailing whitespace on line 926 (else: ).

Tests

  • The PR description references issue #13696 and claims the test should pass, but no new tests are added in this diff. It would be good to confirm the existing test coverage is sufficient.

16 LLM turns · 20 tool calls · 92.9s · 217527 in / 3766 out tokens

Comment thread src/diffusers/models/transformers/transformer_qwenimage.py Outdated
Comment thread src/diffusers/models/transformers/transformer_qwenimage.py Outdated
@DN6
Copy link
Copy Markdown
Collaborator

DN6 commented May 18, 2026

I need @DN6's thoughts on this one. Dhruv, should we tackle it from the forward pass of the model or delegate to the Ulysses utilities we have? Personally, I would prefer doing it from the Ulysses utilities as it would help to keep the forward clean.

Yeah, it would be better to handle via the CP utilities and keep the forward as clean as possible.

@zhtmike
Copy link
Copy Markdown
Contributor Author

zhtmike commented May 19, 2026

Hi @sayakpaul and @DN6 , I've updated the PR using local mask handling.

This is what happens now:

Inputs (T for text token, I for image token):

Position:  0  1  2  3  4  5  6  7 | 8  9  10 11
Content:  T0 T1 T2 T3 T4 T5 T6 T7 | I0 I1 I2 I3
Mask:      1  1  0  0  0  0  0  0 | 1  1  1  1

After CP shard (assume 2 ranks)

Rank 0: text=[T0 T1 T2 T3],  image=[I0 I1], encoder mask=[1,1,0,0], image mask=[1,1] 
→ joint=[T0 T1 T2 T3 I0 I1], mask=[1,1,0,0,1,1]

Rank 1: text=[T4 T5 T6 T7],  image=[I2 I3], encoder mask=[0,0,0,0], image mask=[1,1]  
→ joint=[T4 T5 T6 T7 I2 I3], mask=[0,0,0,0,1,1]

After QKV All-to-all and mask all-gather

Position:  0  1  2  3 | 4  5  | 6  7  8  9  | 10 11
Content:  T0 T1 T2 T3 | I0 I1 | T4 T5 T6 T7 | I2 I3
Mask:      1  1  0  0 | 1  1  |  0  0  0  0 | 1  1

The KV and mask should have the one-to-one correspondence now.

Comment thread src/diffusers/models/transformers/transformer_qwenimage.py Outdated
Comment thread src/diffusers/models/attention_dispatch.py
@github-actions github-actions Bot added size/M PR with diff < 200 LOC and removed size/S PR with diff < 50 LOC labels May 20, 2026
@sayakpaul
Copy link
Copy Markdown
Member

@claude-2-serge could you do a review?

Copy link
Copy Markdown
Member

@sayakpaul sayakpaul left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should there be a test?

Copy link
Copy Markdown
Contributor

@github-actions github-actions Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🤗 Serge says:

This PR correctly fixes the attention mask handling under Ulysses sequence parallelism for QwenImage by moving joint mask construction from the transformer's forward (where it was built once from unsplit tensors) into the per-block attention processor (where it's built from already-sharded encoder_hidden_states_mask and hidden_states). The attention dispatch layers then all-gather the local mask to match the post-all-to-all QKV layout.

Correctness

  • The core fix is sound: under CP, encoder_hidden_states_mask is now split via the _cp_plan (dim=1), passed into each block, and the processor builds the joint mask from the sharded text mask + sharded image ones. The all-gather in TemplatedUlyssesAttention and TemplatedUlyssesAnythingAttention then reconstructs the global mask in the correct rank-concatenated order matching the QKV layout.
  • The "anything" path correctly handles unequal local sizes with pad → all-gather → trim.
  • The guard attn_mask.shape[-1] == S_KV_LOCAL correctly identifies masks that need gathering (local masks) vs. masks that are already global.
  • F.pad on a bool mask defaults to padding with False, which is correct (padded positions should be masked out before trimming).

Minor issues

  • The type annotation attention_mask: None = None is technically valid but unconventional and could confuse tooling/users. A more standard approach would be to keep the original type hint and rely solely on the runtime ValueError.
  • The mask is reconstructed from scratch in every block (cat + unsqueeze), which adds minor overhead compared to the previous approach of building it once. This is the necessary trade-off for correctness under CP, but worth noting.

Suggestions / additional info (dead code trace)

Under the default pipeline call path, the encoder_hidden_states_mask flows correctly from the transformer forward → block → processor → dispatch_attention_fn. The attention_mask parameter on the processor is now effectively dead (always None from external callers, raises if not), which is the intended design — the processor owns mask construction.

23 LLM turns · 26 tool calls · 150.3s · 485476 in / 5617 out tokens

Comment thread src/diffusers/models/transformers/transformer_qwenimage.py Outdated
query, key, value = (_all_to_all_single(x, group) for x in (query, key, value))
query, key, value = (x.flatten(0, 1).permute(1, 0, 2, 3).contiguous() for x in (query, key, value))

if attn_mask is not None and attn_mask.shape[-1] == S_KV_LOCAL:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor: the comment says "All-gather" but the semantic intent is closer to "all-gather the per-rank local masks and concatenate them so the mask covers the full (global) KV sequence after the Ulysses all-to-all on QKV." The current comment is fine but could be slightly more precise about why the layout matches — the all-to-all on QKV concatenates sequence chunks from each rank in rank order, and the all-gather + cat on the mask does the same.

zhtmike and others added 3 commits May 20, 2026 11:29
@github-actions github-actions Bot added the tests label May 20, 2026
@zhtmike
Copy link
Copy Markdown
Contributor Author

zhtmike commented May 20, 2026

Should there be a test?

Done. Added an accuracy test under ContextParallelTesterMixin. This test should guard this (main branch should fail).
And I've tested with other two models flux and flux2. So I think the threshold should be fine.


# Construct joint attention mask once to avoid reconstructing in every block
# This eliminates 60 GPU syncs during training while maintaining torch.compile compatibility
block_attention_kwargs = attention_kwargs.copy() if attention_kwargs is not None else {}
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cc @kashif here since this would revert the optimization as part of this PR #12702

I tried to go over #12702, but was not able to find much detail about this optimization. I would love to understand more about the cause of the sync and performance delta, because the pre-built joint mask does not shard correctly under CP

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

models size/M PR with diff < 200 LOC tests

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants