Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions src/diffusers/models/attention_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2200,6 +2200,12 @@ def forward(
query, key, value = (_all_to_all_single(x, group) for x in (query, key, value))
query, key, value = (x.flatten(0, 1).permute(1, 0, 2, 3).contiguous() for x in (query, key, value))

if attn_mask is not None and attn_mask.shape[-1] == S_KV_LOCAL:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor: the comment says "All-gather" but the semantic intent is closer to "all-gather the per-rank local masks and concatenate them so the mask covers the full (global) KV sequence after the Ulysses all-to-all on QKV." The current comment is fine but could be slightly more precise about why the layout matches — the all-to-all on QKV concatenates sequence chunks from each rank in rank order, and the all-gather + cat on the mask does the same.

# All-gather a local mask so its layout matches the QKV layout after all-to-all.
mask_list = [torch.empty_like(attn_mask) for _ in range(world_size)]
dist.all_gather(mask_list, attn_mask, group=group)
attn_mask = torch.cat(mask_list, dim=-1)
Comment thread
sayakpaul marked this conversation as resolved.

out = forward_op(
ctx,
query,
Expand Down Expand Up @@ -2399,6 +2405,8 @@ def forward(
ctx.backward_op = backward_op
ctx._parallel_config = _parallel_config

_, S_KV_LOCAL, _, _ = key.shape

metadata = ulysses_anything_metadata(query)
query_wait = all_to_all_single_any_qkv_async(query, group, **metadata)
key_wait = all_to_all_single_any_qkv_async(key, group, **metadata)
Expand All @@ -2408,6 +2416,19 @@ def forward(
key = key_wait() # type: torch.Tensor
value = value_wait() # type: torch.Tensor

if attn_mask is not None and attn_mask.shape[-1] == S_KV_LOCAL:
# All-gather a local mask to match the post-all-to-all global sequence.
# The "anything" path allows unequal local sizes, so we pad to the
# maximum across ranks before all-gathering, then trim back.
mask_local_sizes = gather_size_by_comm(attn_mask.shape[-1], group)
max_local = max(mask_local_sizes)
if attn_mask.shape[-1] < max_local:
attn_mask = F.pad(attn_mask, (0, max_local - attn_mask.shape[-1]))
mask_list = [torch.empty_like(attn_mask) for _ in range(dist.get_world_size(group=group))]
dist.all_gather(mask_list, attn_mask, group=group)
attn_mask = torch.cat(mask_list, dim=-1)
attn_mask = attn_mask[..., : sum(mask_local_sizes)]

out = forward_op(
ctx,
query,
Expand Down
32 changes: 17 additions & 15 deletions src/diffusers/models/transformers/transformer_qwenimage.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,18 @@ def __call__(
if encoder_hidden_states is None:
raise ValueError("QwenDoubleStreamAttnProcessor2_0 requires encoder_hidden_states (text stream)")

if attention_mask is not None:
raise ValueError(
"QwenDoubleStreamAttnProcessor2_0 does not accept an external attention_mask. "
"Pass encoder_hidden_states_mask to let the processor build the joint mask."
)

if encoder_hidden_states_mask is not None:
seq_img = hidden_states.shape[1]
image_mask = torch.ones((hidden_states.shape[0], seq_img), dtype=torch.bool, device=hidden_states.device)
attention_mask = torch.cat([encoder_hidden_states_mask, image_mask], dim=1)
attention_mask = attention_mask[:, None, None, :]

seq_txt = encoder_hidden_states.shape[1]

# Compute QKV for image stream (sample projections)
Expand Down Expand Up @@ -770,6 +782,7 @@ class QwenImageTransformer2DModel(
},
"transformer_blocks.*": {
"modulate_index": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False),
"encoder_hidden_states_mask": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False),
},
"pos_embed": {
0: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True),
Expand Down Expand Up @@ -909,38 +922,27 @@ def forward(

image_rotary_emb = self.pos_embed(img_shapes, max_txt_seq_len=text_seq_len, device=hidden_states.device)

# Construct joint attention mask once to avoid reconstructing in every block
# This eliminates 60 GPU syncs during training while maintaining torch.compile compatibility
block_attention_kwargs = attention_kwargs.copy() if attention_kwargs is not None else {}
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cc @kashif here since this would revert the optimization as part of this PR #12702

I tried to go over #12702, but was not able to find much detail about this optimization. I would love to understand more about the cause of the sync and performance delta, because the pre-built joint mask does not shard correctly under CP

if encoder_hidden_states_mask is not None:
# Build joint mask: [text_mask, all_ones_for_image]
batch_size, image_seq_len = hidden_states.shape[:2]
image_mask = torch.ones((batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device)
joint_attention_mask = torch.cat([encoder_hidden_states_mask, image_mask], dim=1)
joint_attention_mask = joint_attention_mask[:, None, None, :]
block_attention_kwargs["attention_mask"] = joint_attention_mask

for index_block, block in enumerate(self.transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
encoder_hidden_states,
None, # Don't pass encoder_hidden_states_mask (using attention_mask instead)
encoder_hidden_states_mask,
temb,
image_rotary_emb,
block_attention_kwargs,
attention_kwargs,
modulate_index,
)

else:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
encoder_hidden_states_mask=None, # Don't pass (using attention_mask instead)
encoder_hidden_states_mask=encoder_hidden_states_mask,
temb=temb,
image_rotary_emb=image_rotary_emb,
joint_attention_kwargs=block_attention_kwargs,
joint_attention_kwargs=attention_kwargs,
modulate_index=modulate_index,
)

Expand Down
91 changes: 91 additions & 0 deletions tests/models/testing_utils/parallelism.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,51 @@ def _custom_mesh_worker(
dist.destroy_process_group()


def _context_parallel_correctness_worker(
rank, world_size, master_port, model_class, init_dict, state_dict, cp_dict, inputs_dict, return_dict
):
"""Worker that runs a CP forward pass and returns the output tensor for numerical comparison."""
try:
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(master_port)
os.environ["RANK"] = str(rank)
os.environ["WORLD_SIZE"] = str(world_size)

device_config = DEVICE_CONFIG.get(torch_device, DEVICE_CONFIG["cuda"])
backend = device_config["backend"]
device_module = device_config["module"]

dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
device_module.set_device(rank)
device = torch.device(f"{torch_device}:{rank}")

model = model_class(**init_dict)
model.load_state_dict(state_dict)
model.to(device)
model.eval()

inputs_on_device = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()}

cp_config = ContextParallelConfig(**cp_dict)
model.enable_parallelism(config=cp_config)

with torch.no_grad():
output = model(**inputs_on_device, return_dict=False)[0]

if rank == 0:
return_dict["status"] = "success"
# Serialise via nested list so the manager dict can transport it across processes.
return_dict["output"] = output.cpu().tolist()

except Exception as e:
if rank == 0:
return_dict["status"] = "error"
return_dict["error"] = str(e)
finally:
if dist.is_initialized():
dist.destroy_process_group()


@is_context_parallel
@require_torch_multi_accelerator
class ContextParallelTesterMixin:
Expand Down Expand Up @@ -369,6 +414,52 @@ def test_context_parallel_custom_mesh(self, cp_type, mesh_shape, mesh_dim_names)
f"Custom mesh context parallel inference failed: {return_dict.get('error', 'Unknown error')}"
)

@pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"])
def test_context_parallel_output_correctness(self, cp_type, batch_size: int = 1):
"""Verify that CP output is numerically identical to a single-GPU reference forward pass."""
if not torch.distributed.is_available():
pytest.skip("torch.distributed is not available.")

if not hasattr(self.model_class, "_cp_plan") or self.model_class._cp_plan is None:
pytest.skip("Model does not have a _cp_plan defined for context parallel inference.")

if cp_type == "ring_degree":
active_backend, _ = _AttentionBackendRegistry.get_active_backend()
if active_backend == AttentionBackendName.NATIVE:
pytest.skip("Ring attention is not supported with the native attention backend.")

world_size = 2
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs(batch_size=batch_size)

# Single-GPU reference
model = self.model_class(**init_dict).eval().to(torch_device)
state_dict = {k: v.cpu() for k, v in model.state_dict().items()}
with torch.no_grad():
ref_output = model(**inputs_dict, return_dict=False)[0].cpu()

# Context-parallel run with the same weights
inputs_cpu = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()}
cp_dict = {cp_type: world_size}

master_port = _find_free_port()
manager = mp.Manager()
return_dict = manager.dict()

mp.spawn(
_context_parallel_correctness_worker,
args=(world_size, master_port, self.model_class, init_dict, state_dict, cp_dict, inputs_cpu, return_dict),
nprocs=world_size,
join=True,
)

assert return_dict.get("status") == "success", (
f"Context parallel correctness check failed: {return_dict.get('error', 'Unknown error')}"
)

cp_output = torch.tensor(return_dict["output"])
torch.testing.assert_close(ref_output, cp_output, atol=1e-4, rtol=1e-4)


@is_attention
@is_context_parallel
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,15 @@ class TestQwenImageTransformerAttention(QwenImageTransformerTesterConfig, Attent
class TestQwenImageTransformerContextParallel(QwenImageTransformerTesterConfig, ContextParallelTesterMixin):
"""Context Parallel inference tests for QwenImage Transformer."""

def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:
inputs = super().get_dummy_inputs(batch_size=batch_size)
encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"]
encoder_hidden_states_mask[:, 1] = 0
encoder_hidden_states_mask[:, 3] = 0
encoder_hidden_states_mask[:, 5:] = 0
inputs["encoder_hidden_states_mask"] = encoder_hidden_states_mask
return inputs


class TestQwenImageTransformerContextParallelAttnBackends(
QwenImageTransformerTesterConfig, ContextParallelAttentionBackendsTesterMixin
Expand Down
Loading