
Fix Qwen attention mask crash with diffusers >=0.37 #748

Merged

jaretburkett merged 2 commits into ostris:main from Rasaboun:fix/qwen-mask on Mar 23, 2026

Conversation

@Rasaboun
Contributor

Problem

diffusers v0.37 (PR #12987) optimizes all-ones attention masks to None in encode_prompt() when there is no padding. This causes an AttributeError: 'NoneType' object has no attribute 'to' crash in all three Qwen image extensions during training.

The recent fix in 295094b addresses qwen_image.py by switching to the private _get_qwen_prompt_embeds() API, but qwen_image_edit.py and qwen_image_edit_plus.py still crash.
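The failure mode can be reproduced in isolation. The sketch below is illustrative, not ai-toolkit code; `downstream_step` is a stand-in for the extensions' unconditional `.to()` call on the mask:

```python
def downstream_step(prompt_embeds_mask, device):
    # The Qwen extensions call .to() on the mask unconditionally
    return prompt_embeds_mask.to(device)

# With diffusers >=0.37, encode_prompt() may hand back None instead of
# an all-ones mask when the prompt has no padding.
mask = None

try:
    downstream_step(mask, "cpu")
except AttributeError as e:
    print(e)  # 'NoneType' object has no attribute 'to'
```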

Fix

In get_prompt_embeds() of all three Qwen variants, reconstruct the all-ones mask whenever encode_prompt() returns None:

prompt_embeds, prompt_embeds_mask = self.pipeline.encode_prompt(...)
if prompt_embeds_mask is None:
    prompt_embeds_mask = torch.ones(
        prompt_embeds.shape[:2], device=prompt_embeds.device, dtype=torch.int64
    )

This approach:

  • Uses the public encode_prompt() API (not _get_qwen_prompt_embeds)
  • Works with both old and new diffusers versions
  • Requires zero downstream changes — the mask is always a tensor
  • Fixes all three variants: base, edit, and edit_plus

Also removes redundant duplicate mask assignments in qwen_image_edit.py and qwen_image_edit_plus.py.

Fixes #740
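As a self-contained sketch of the guard, runnable outside the pipeline: `fake_encode_prompt` is a stand-in for the diffusers call, and the embedding shape is illustrative only.

```python
import torch

def fake_encode_prompt():
    # Stand-in for pipeline.encode_prompt(): simulates diffusers >=0.37
    # returning None in place of an all-ones attention mask
    prompt_embeds = torch.randn(1, 20, 64)  # (batch, seq_len, hidden)
    return prompt_embeds, None

prompt_embeds, prompt_embeds_mask = fake_encode_prompt()
if prompt_embeds_mask is None:
    # Reconstruct the all-ones mask so downstream code always sees a tensor
    prompt_embeds_mask = torch.ones(
        prompt_embeds.shape[:2], device=prompt_embeds.device, dtype=torch.int64
    )

print(prompt_embeds_mask.shape)  # torch.Size([1, 20])
```

Because the reconstructed mask matches what encode_prompt() previously returned for unpadded prompts, no downstream call site needs to change.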

Copilot AI review requested due to automatic review settings March 23, 2026 20:26

Copilot AI left a comment


Pull request overview

Fixes a training-time crash in the Qwen image model extensions when used with diffusers>=0.37, where encode_prompt() may return None for an all-ones attention mask (no padding).

Changes:

  • Reconstruct an all-ones prompt_embeds_mask tensor when encode_prompt() returns None in all Qwen variants (base/edit/edit_plus).
  • Remove redundant re-assignment of prompt_embeds_mask from text_embeddings.attention_mask in the edit/edit_plus noise prediction paths.
  • Ensure the mask passed into the transformer is a tensor (and in qwen_image_edit.py, passed as a detached tensor).

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated no comments.

Files changed:

  • extensions_built_in/diffusion_models/qwen_image/qwen_image.py: switches prompt embedding extraction back to the public encode_prompt() and guards against None masks by recreating an all-ones mask.
  • extensions_built_in/diffusion_models/qwen_image/qwen_image_edit.py: adds the same None-mask reconstruction, removes the redundant mask reassignment, and passes a detached mask into the transformer.
  • extensions_built_in/diffusion_models/qwen_image/qwen_image_edit_plus.py: adds the same None-mask reconstruction and removes the redundant mask reassignment in noise prediction.


@jaretburkett jaretburkett merged commit 99a4a58 into ostris:main Mar 23, 2026
3 of 4 checks passed
@jaretburkett
Contributor

Thank you! @Rasaboun

mocliamg1 pushed a commit to mocliamg1/ai-toolkit that referenced this pull request Apr 23, 2026

Development

Successfully merging this pull request may close these issues:

  • error when training qwen-image-2512