3 changes: 3 additions & 0 deletions src/diffusers/models/attention_dispatch.py
@@ -813,6 +813,9 @@ def _native_attention_forward_op(
    if return_lse:
        raise ValueError("Native attention does not support return_lse=True")

+   if attn_mask is not None and attn_mask.dim() == 2:
+       attn_mask = attn_mask[:, None, None, :]
Member
Is this Qwen specific?

Contributor Author

I haven't tested with other models, but I think models that take a 2D mask input should have a similar problem.
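For illustration, here is a minimal repro sketch of what goes wrong (toy shapes chosen just for this example, not taken from the actual model):

```python
import torch
import torch.nn.functional as F

# Toy shapes (assumptions for illustration only).
B, H, L, D = 2, 8, 16, 64
q = torch.randn(B, H, L, D)
k = torch.randn(B, H, L, D)
v = torch.randn(B, H, L, D)

# A 2D padding mask as a pipeline might pass it: [batch, seq_len].
mask_2d = torch.ones(B, L, dtype=torch.bool)
mask_2d[:, -4:] = False  # pretend the last 4 tokens are padding

# SDPA expects a mask broadcastable to [B, H, L_q, L_k]. A [B, L] mask
# lines its batch dim up with L_q, so this raises a broadcast error
# (or silently misbroadcasts when B happens to equal L_q).
try:
    F.scaled_dot_product_attention(q, k, v, attn_mask=mask_2d)
except RuntimeError as e:
    print("2D mask fails:", e)

# Expanding to [B, 1, 1, L_k] -- what this patch does -- broadcasts
# correctly over heads and query positions.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask_2d[:, None, None, :])
print(out.shape)  # torch.Size([2, 8, 16, 64])
```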

Member

Possible to check out one other? And also run the class TestFluxTransformerContextParallel(FluxTransformerTesterConfig, ContextParallelTesterMixin)?

Contributor Author

Thanks. From a quick scan, most models seem to handle the mask shape correctly in their own implementations. So I’ve limited the modification to QwenImage only.

Should I run any test cases?

Member

Thanks! Maybe we could add a similar test to the class TestFluxTransformerContextParallel(FluxTransformerTesterConfig, ContextParallelTesterMixin)?

I will give you a ping once it's refactored to follow the latest pattern.

Member

Sorry, disregard my suggestion on using the cuDNN backend.

Yes, native attention x Ulysses is perfect for single-prompt input. Currently, batched inputs have some problems.

Is it the case just for Qwen, or does the same happen for Flux as well? Also, the test under consideration -- does it not use a single prompt?

Contributor Author

> Is it the case just for Qwen, or does the same happen for Flux as well?

So far, I have only found that Qwen has this problem. Other models, such as Z-Image and HunyuanImage, expand the attention mask in a similar way before entering the attention block. For Flux, I tested with the main branch, and it works fine with both CP and batched inputs.
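For reference, the pre-expansion those models do amounts to roughly this (the helper name here is hypothetical; each model inlines it differently):

```python
def prepare_attention_mask(attn_mask):
    # [batch, seq_len] -> [batch, 1, 1, seq_len], broadcastable over
    # heads and query positions inside scaled_dot_product_attention.
    if attn_mask is not None and attn_mask.dim() == 2:
        attn_mask = attn_mask[:, None, None, :]
    return attn_mask
```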

> Also, the test under consideration -- does it not use a single prompt?

I am wondering whether we should add a batched-input test if possible. To start, I think we should first ensure that all existing unit tests pass without modifying them.

Some background on this bug: we are working on a training engine based on the Diffusers backend, using QwenImage as the first example. We therefore need a combination of batched inputs (for high throughput) and Ulysses SP support, which is how we hit this bug during the forward pass.

Member

Agreed and thanks so much for the context!

> I am wondering whether we should add a batched-input test if possible. To start, I think we should first ensure that all existing unit tests pass without modifying them.

Would you like to take a crack at this? We'll be quick to review.

I think we first need to ensure that the test_context_parallel_inference() test is xfailed when ring attention is enabled with the SDPA backend. #13182 is adding a test suite for CP backends and attention backends.

And then a test for batched inputs.

Then let's revisit this PR?

WDYT?

Contributor Author

Sure, no problem, I will add a unit test for batched input.
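Roughly the invariant I plan to check, as a standalone sketch (test name and shapes are placeholders, not the actual diffusers test harness):

```python
import torch
import torch.nn.functional as F

def test_batched_2d_mask_matches_per_sample():
    torch.manual_seed(0)
    # Placeholder shapes; the real test would run the QwenImage
    # transformer with batch size > 1 under context parallelism.
    B, H, L, D = 2, 4, 8, 32
    q, k, v = (torch.randn(B, H, L, D) for _ in range(3))
    mask = torch.ones(B, L, dtype=torch.bool)
    mask[1, -2:] = False  # second sample has two padding tokens

    # Batched forward with the mask expanded as in this PR.
    batched = F.scaled_dot_product_attention(q, k, v, attn_mask=mask[:, None, None, :])

    # Each sample run alone must give the same result.
    for i in range(B):
        single = F.scaled_dot_product_attention(
            q[i : i + 1], k[i : i + 1], v[i : i + 1],
            attn_mask=mask[i : i + 1, None, None, :],
        )
        torch.testing.assert_close(batched[i : i + 1], single)
```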

Contributor Author

Hi @sayakpaul, I have added PR #13312 for this, could you take a look?


    # used for backward pass
    if _save_ctx:
        ctx.save_for_backward(query, key, value)