Fix the QwenImage Attention mask under Ulysses SP #13756
base: main
Changes from all commits
84c545f
1d391b5
9346da0
026327f
52bd940
03017a6
7d21597
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -497,6 +497,18 @@ def __call__( | |
| if encoder_hidden_states is None: | ||
| raise ValueError("QwenDoubleStreamAttnProcessor2_0 requires encoder_hidden_states (text stream)") | ||
|
|
||
| if attention_mask is not None: | ||
| raise ValueError( | ||
| "QwenDoubleStreamAttnProcessor2_0 does not accept an external attention_mask. " | ||
| "Pass encoder_hidden_states_mask to let the processor build the joint mask." | ||
| ) | ||
|
|
||
| if encoder_hidden_states_mask is not None: | ||
| seq_img = hidden_states.shape[1] | ||
| image_mask = torch.ones((hidden_states.shape[0], seq_img), dtype=torch.bool, device=hidden_states.device) | ||
| attention_mask = torch.cat([encoder_hidden_states_mask, image_mask], dim=1) | ||
| attention_mask = attention_mask[:, None, None, :] | ||
|
|
||
| seq_txt = encoder_hidden_states.shape[1] | ||
|
|
||
| # Compute QKV for image stream (sample projections) | ||
|
|
@@ -770,6 +782,7 @@ class QwenImageTransformer2DModel( | |
| }, | ||
| "transformer_blocks.*": { | ||
| "modulate_index": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False), | ||
| "encoder_hidden_states_mask": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False), | ||
| }, | ||
| "pos_embed": { | ||
| 0: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True), | ||
|
|
@@ -909,38 +922,27 @@ def forward( | |
|
|
||
| image_rotary_emb = self.pos_embed(img_shapes, max_txt_seq_len=text_seq_len, device=hidden_states.device) | ||
|
|
||
| # Construct joint attention mask once to avoid reconstructing in every block | ||
| # This eliminates 60 GPU syncs during training while maintaining torch.compile compatibility | ||
| block_attention_kwargs = attention_kwargs.copy() if attention_kwargs is not None else {} | ||
|
Collaborator
cc @kashif here, since this would revert the optimization from #12702. I tried to go over #12702 but could not find much detail about that optimization. I would love to understand the cause of the sync and the performance delta, because the pre-built joint mask does not shard correctly under CP.
||
| if encoder_hidden_states_mask is not None: | ||
| # Build joint mask: [text_mask, all_ones_for_image] | ||
| batch_size, image_seq_len = hidden_states.shape[:2] | ||
| image_mask = torch.ones((batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device) | ||
| joint_attention_mask = torch.cat([encoder_hidden_states_mask, image_mask], dim=1) | ||
| joint_attention_mask = joint_attention_mask[:, None, None, :] | ||
| block_attention_kwargs["attention_mask"] = joint_attention_mask | ||
|
|
||
| for index_block, block in enumerate(self.transformer_blocks): | ||
| if torch.is_grad_enabled() and self.gradient_checkpointing: | ||
| encoder_hidden_states, hidden_states = self._gradient_checkpointing_func( | ||
| block, | ||
| hidden_states, | ||
| encoder_hidden_states, | ||
| None, # Don't pass encoder_hidden_states_mask (using attention_mask instead) | ||
| encoder_hidden_states_mask, | ||
| temb, | ||
| image_rotary_emb, | ||
| block_attention_kwargs, | ||
| attention_kwargs, | ||
| modulate_index, | ||
| ) | ||
|
|
||
| else: | ||
| encoder_hidden_states, hidden_states = block( | ||
| hidden_states=hidden_states, | ||
| encoder_hidden_states=encoder_hidden_states, | ||
| encoder_hidden_states_mask=None, # Don't pass (using attention_mask instead) | ||
| encoder_hidden_states_mask=encoder_hidden_states_mask, | ||
| temb=temb, | ||
| image_rotary_emb=image_rotary_emb, | ||
| joint_attention_kwargs=block_attention_kwargs, | ||
| joint_attention_kwargs=attention_kwargs, | ||
| modulate_index=modulate_index, | ||
| ) | ||
|
|
||
|
|
||
Minor: the comment says "All-gather", but the intent is more specific: all-gather the per-rank local masks and concatenate them so that the mask covers the full (global) KV sequence after the Ulysses all-to-all on QKV. The current comment is fine, but it could state why the layouts match: the all-to-all on QKV concatenates the sequence chunks from each rank in rank order, and the all-gather + cat on the mask does the same.
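To make the layout argument concrete, here is a minimal single-process sketch of the two mask layouts. It is not the PR's implementation: plain Python lists stand in for tensors, the helpers `shard`, `all_gather_cat`, and `masked_out` are hypothetical names, no `torch.distributed` call is involved, and it assumes equal contiguous shards per rank with the text tokens placed before the image tokens in each rank's joint sequence.

```python
def shard(seq, world_size):
    """Split a list into world_size equal, contiguous chunks (one per rank)."""
    assert len(seq) % world_size == 0
    n = len(seq) // world_size
    return [seq[r * n:(r + 1) * n] for r in range(world_size)]


def all_gather_cat(per_rank):
    """Toy stand-in for an all-gather followed by concatenation in rank order."""
    return [item for chunk in per_rank for item in chunk]


def masked_out(tokens, mask):
    """Return the tokens that a boolean keep-mask would exclude from attention."""
    return [tok for tok, keep in zip(tokens, mask) if not keep]


world_size = 2

# Toy single-sample inputs: 4 text tokens (the last is padding) and 4 image tokens.
text_tokens = ["t0", "t1", "t2", "pad"]
text_mask = [True, True, True, False]
image_tokens = ["i0", "i1", "i2", "i3"]

# Text, text mask, and image tokens are each sharded along the sequence dimension.
txt_shards = shard(text_tokens, world_size)
mask_shards = shard(text_mask, world_size)
img_shards = shard(image_tokens, world_size)

# Each rank builds its local joint sequence and local joint mask: [text, image].
local_joint = [txt_shards[r] + img_shards[r] for r in range(world_size)]
local_mask = [mask_shards[r] + [True] * len(img_shards[r]) for r in range(world_size)]

# Global KV order after gathering the sequence chunks in rank order.
global_kv = all_gather_cat(local_joint)
# Gathering the per-rank masks the same way yields the same layout.
gathered_mask = all_gather_cat(local_mask)

# A joint mask pre-built from the unsharded inputs: [all text, all image].
prebuilt_mask = text_mask + [True] * len(image_tokens)

print(global_kv)                             # ['t0', 't1', 'i0', 'i1', 't2', 'pad', 'i2', 'i3']
print(masked_out(global_kv, gathered_mask))  # ['pad']
print(masked_out(global_kv, prebuilt_mask))  # ['i1']
```

In this toy setup the gathered per-rank mask excludes exactly the padding token, while the pre-built mask lines up with the wrong position and excludes an image token instead. That is the mismatch the review comments describe; the real tensor shapes and collectives in the PR may differ.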