Skip to content

Commit 7f43cb1

Browse files
authored
fix Qwen-Image series context parallel (#12970)
* fix qwen-image cp
* relax attn_mask limit for cp
* CP plan compatible with zero_cond_t
* move modulate_index plan to top level
1 parent 5efb81f commit 7f43cb1

File tree

2 files changed

+5
-4
lines changed

2 files changed

+5
-4
lines changed

src/diffusers/models/attention_dispatch.py

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1573,8 +1573,6 @@ def _templated_context_parallel_attention(
15731573
backward_op,
15741574
_parallel_config: Optional["ParallelConfig"] = None,
15751575
):
1576-
if attn_mask is not None:
1577-
raise ValueError("Attention mask is not yet supported for templated attention.")
15781576
if is_causal:
15791577
raise ValueError("Causal attention is not yet supported for templated attention.")
15801578
if enable_gqa:

src/diffusers/models/transformers/transformer_qwenimage.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -761,11 +761,14 @@ class QwenImageTransformer2DModel(
761761
_no_split_modules = ["QwenImageTransformerBlock"]
762762
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
763763
_repeated_blocks = ["QwenImageTransformerBlock"]
764+
# Make CP plan compatible with https://github.com/huggingface/diffusers/pull/12702
764765
_cp_plan = {
765-
"": {
766+
"transformer_blocks.0": {
766767
"hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
767768
"encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
768-
"encoder_hidden_states_mask": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False),
769+
},
770+
"transformer_blocks.*": {
771+
"modulate_index": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False),
769772
},
770773
"pos_embed": {
771774
0: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True),

0 commit comments

Comments (0)