Fix the QwenImage Attention mask under Ulysses SP#13756
Conversation
|
@sayakpaul, I can make this fix more generic by addressing Here’s my suggestion:
However, this approach requires handling masks in attention layer locally (and with extra communication cost) and may revert some of the performance improvements introduced in #12702. So do you have any suggestions? |
|
Could you explain a bit why the actual mask is Position: 0 1 2 3 | 4 5 | 6 7 8 9 | 10 11
Mask: 1 1 0 0 | 1 1 | 0 0 0 0 | 1 1 |
since the original input is (T: token for text; I: Token for image) after all-to-all so the mask should be in the correct position matching the content, avoid the wrong attention of KV |
|
I need @DN6's thoughts on this one. Dhruv, should we tackle it from the forward pass of the model or delegate to the Ulysses utilities we have? Personally, I would prefer doing it from the Ulysses utilities as it would help to keep the forward clean. |
|
@claude-2-serge could you do a review? |
There was a problem hiding this comment.
🤗 Serge says:
The fix correctly identifies that under Ulysses sequence parallelism, the attention mask must be interleaved to match the post-all-to-all token layout. The approach of chunking text and image masks separately and interleaving them is sound.
Potential Issue
_parallel_confignot declared on the transformer model class:getattr(self, "_parallel_config", None)at line 919 referencesselfwhich is theQwenImageTransformer2DModelinstance. However,_parallel_config = Noneis only declared onQwenDoubleStreamAttnProcessor2_0(line 480), not on the transformer model class itself. Other transformer models (e.g.,transformer_flux.py) declare_parallel_configon both the processor and the model class. If the parallel setup code sets_parallel_configon the model viasetattr, this works — but it would be more robust and consistent to also declare_parallel_config = Noneas a class attribute onQwenImageTransformer2DModel.
Style
- Trailing whitespace on line 926 (
else:).
Tests
- The PR description references issue #13696 and claims the test should pass, but no new tests are added in this diff. It would be good to confirm the existing test coverage is sufficient.
16 LLM turns · 20 tool calls · 92.9s · 217527 in / 3766 out tokens
Yeah, it would be better to handle via the CP utilities and keep the forward as clean as possible. |
|
Hi @sayakpaul and @DN6 , I've updated the PR using local mask handling. This is what happens now: Inputs (T for text token, I for image token): After CP shard (assume 2 ranks) After QKV All-to-all and mask all-gather The KV and mask should have the one-to-one correspondence now. |
|
@claude-2-serge could you do a review? |
There was a problem hiding this comment.
🤗 Serge says:
This PR correctly fixes the attention mask handling under Ulysses sequence parallelism for QwenImage by moving joint mask construction from the transformer's forward (where it was built once from unsplit tensors) into the per-block attention processor (where it's built from already-sharded encoder_hidden_states_mask and hidden_states). The attention dispatch layers then all-gather the local mask to match the post-all-to-all QKV layout.
Correctness
- The core fix is sound: under CP,
encoder_hidden_states_maskis now split via the_cp_plan(dim=1), passed into each block, and the processor builds the joint mask from the sharded text mask + sharded image ones. The all-gather inTemplatedUlyssesAttentionandTemplatedUlyssesAnythingAttentionthen reconstructs the global mask in the correct rank-concatenated order matching the QKV layout. - The "anything" path correctly handles unequal local sizes with pad → all-gather → trim.
- The guard
attn_mask.shape[-1] == S_KV_LOCALcorrectly identifies masks that need gathering (local masks) vs. masks that are already global. F.padon a bool mask defaults to padding withFalse, which is correct (padded positions should be masked out before trimming).
Minor issues
- The type annotation
attention_mask: None = Noneis technically valid but unconventional and could confuse tooling/users. A more standard approach would be to keep the original type hint and rely solely on the runtimeValueError. - The mask is reconstructed from scratch in every block (cat + unsqueeze), which adds minor overhead compared to the previous approach of building it once. This is the necessary trade-off for correctness under CP, but worth noting.
Suggestions / additional info (dead code trace)
Under the default pipeline call path, the encoder_hidden_states_mask flows correctly from the transformer forward → block → processor → dispatch_attention_fn. The attention_mask parameter on the processor is now effectively dead (always None from external callers, raises if not), which is the intended design — the processor owns mask construction.
23 LLM turns · 26 tool calls · 150.3s · 485476 in / 5617 out tokens
| query, key, value = (_all_to_all_single(x, group) for x in (query, key, value)) | ||
| query, key, value = (x.flatten(0, 1).permute(1, 0, 2, 3).contiguous() for x in (query, key, value)) | ||
|
|
||
| if attn_mask is not None and attn_mask.shape[-1] == S_KV_LOCAL: |
There was a problem hiding this comment.
Minor: the comment says "All-gather" but the semantic intent is closer to "all-gather the per-rank local masks and concatenate them so the mask covers the full (global) KV sequence after the Ulysses all-to-all on QKV." The current comment is fine but could be slightly more precise about why the layout matches — the all-to-all on QKV concatenates sequence chunks from each rank in rank order, and the all-gather + cat on the mask does the same.
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Done. Added an accuracy test under |
|
|
||
| # Construct joint attention mask once to avoid reconstructing in every block | ||
| # This eliminates 60 GPU syncs during training while maintaining torch.compile compatibility | ||
| block_attention_kwargs = attention_kwargs.copy() if attention_kwargs is not None else {} |
There was a problem hiding this comment.
cc @kashif here since this would revert the optimization as part of this PR #12702
I tried to go over #12702, but was not able to find much detail about this optimization. I would love to understand more about the cause of the sync and performance delta, because the pre-built joint mask does not shard correctly under CP
What does this PR do?
This fixes the issue #13696 . The test should be passed after this PR.
This the problem I found: The mask does not have a one-to-one correspondence with the content.
For QwenImage Pipeline, use the following example
After CP shard (assume 2 ranks)
After All-to-all
But the mask is not handled correctly
This PR makes mask correctly assigned
Fixes # (issue)
Before submitting
documentation guidelines, and
here are tips on formatting docstrings.
Who can review?
@sayakpaul
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.