Skip to content

Commit bcbbded

Browse files
zhangtao0408TaoZhang-Worksayakpaulgithub-actions[bot]DN6
authored
[Bug] Fix QwenImageEditPlus Series on NPU (#13017)
* [Bug Fix][Qwen-Image-Edit] Fix Qwen-Image-Edit series on NPU * Enhance NPU attention handling by converting attention mask to boolean and refining mask checks. * Refine attention mask handling in NPU attention function to improve validation and conversion logic. * Clean Code * Refine attention mask processing in NPU attention functions to enhance performance and validation. * Remove item() ops on npu fa backend. * Reuse NPU attention mask by `_maybe_modify_attn_mask_npu` * Apply style fixes * Update src/diffusers/models/attention_dispatch.py --------- Co-authored-by: zhangtao <zhangtao529@huawei.com> Co-authored-by: Sayak Paul <spsayakpaul@gmail.com> Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
1 parent 35086ac commit bcbbded

File tree

2 files changed

+32
-4
lines changed

2 files changed

+32
-4
lines changed

src/diffusers/models/attention_dispatch.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1117,6 +1117,26 @@ def _sage_attention_backward_op(
11171117
raise NotImplementedError("Backward pass is not implemented for Sage attention.")
11181118

11191119

1120+
def _maybe_modify_attn_mask_npu(query: torch.Tensor, key: torch.Tensor, attn_mask: torch.Tensor | None = None):
1121+
# Skip Attention Mask if all values are 1, `None` mask can speedup the computation
1122+
if attn_mask is not None and torch.all(attn_mask != 0):
1123+
attn_mask = None
1124+
1125+
# Reshape Attention Mask: [batch_size, seq_len_k] -> [batch_size, 1, sqe_len_q, seq_len_k]
1126+
# https://www.hiascend.com/document/detail/zh/Pytorch/730/apiref/torchnpuCustomsapi/docs/context/torch_npu-npu_fusion_attention.md
1127+
if (
1128+
attn_mask is not None
1129+
and attn_mask.ndim == 2
1130+
and attn_mask.shape[0] == query.shape[0]
1131+
and attn_mask.shape[1] == key.shape[1]
1132+
):
1133+
B, Sq, Skv = attn_mask.shape[0], query.shape[1], key.shape[1]
1134+
attn_mask = ~attn_mask.to(torch.bool)
1135+
attn_mask = attn_mask.unsqueeze(1).expand(B, Sq, Skv).unsqueeze(1).contiguous()
1136+
1137+
return attn_mask
1138+
1139+
11201140
def _npu_attention_forward_op(
11211141
ctx: torch.autograd.function.FunctionCtx,
11221142
query: torch.Tensor,
@@ -1134,11 +1154,14 @@ def _npu_attention_forward_op(
11341154
if return_lse:
11351155
raise ValueError("NPU attention backend does not support setting `return_lse=True`.")
11361156

1157+
attn_mask = _maybe_modify_attn_mask_npu(query, key, attn_mask)
1158+
11371159
out = npu_fusion_attention(
11381160
query,
11391161
key,
11401162
value,
11411163
query.size(2), # num_heads
1164+
atten_mask=attn_mask,
11421165
input_layout="BSND",
11431166
pse=None,
11441167
scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale,
@@ -2668,16 +2691,17 @@ def _native_npu_attention(
26682691
return_lse: bool = False,
26692692
_parallel_config: "ParallelConfig" | None = None,
26702693
) -> torch.Tensor:
2671-
if attn_mask is not None:
2672-
raise ValueError("`attn_mask` is not supported for NPU attention")
26732694
if return_lse:
26742695
raise ValueError("NPU attention backend does not support setting `return_lse=True`.")
26752696
if _parallel_config is None:
2697+
attn_mask = _maybe_modify_attn_mask_npu(query, key, attn_mask)
2698+
26762699
out = npu_fusion_attention(
26772700
query,
26782701
key,
26792702
value,
26802703
query.size(2), # num_heads
2704+
atten_mask=attn_mask,
26812705
input_layout="BSND",
26822706
pse=None,
26832707
scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale,
@@ -2692,7 +2716,7 @@ def _native_npu_attention(
26922716
query,
26932717
key,
26942718
value,
2695-
None,
2719+
attn_mask,
26962720
dropout_p,
26972721
None,
26982722
scale,

src/diffusers/models/transformers/transformer_qwenimage.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,11 @@ def compute_text_seq_len_from_mask(
164164
position_ids = torch.arange(text_seq_len, device=encoder_hidden_states.device, dtype=torch.long)
165165
active_positions = torch.where(encoder_hidden_states_mask, position_ids, position_ids.new_zeros(()))
166166
has_active = encoder_hidden_states_mask.any(dim=1)
167-
per_sample_len = torch.where(has_active, active_positions.max(dim=1).values + 1, torch.as_tensor(text_seq_len))
167+
per_sample_len = torch.where(
168+
has_active,
169+
active_positions.max(dim=1).values + 1,
170+
torch.as_tensor(text_seq_len, device=encoder_hidden_states.device),
171+
)
168172
return text_seq_len, per_sample_len, encoder_hidden_states_mask
169173

170174

0 commit comments

Comments
 (0)