The backward pass of QwenImageTransformer fails with Ulysses SP. #13319

@zhtmike


Describe the bug

I am not sure whether the backward pass of Ulysses SP is formally supported, but backward ops such as _native_attention_backward_op are implemented in the codebase. When I run QwenImageTransformer with a backward pass, I hit shape-mismatch errors.

Reproduction

The error can be reproduced with the following script (it relies on PR #13278 to fix the forward pass first):

import argparse

import torch
import torch.distributed as dist
import torch.nn.functional as F
from diffusers.models import QwenImageTransformer2DModel
from diffusers.models._modeling_parallel import ContextParallelConfig


def init_model(device, enable_sp: bool):
    model = QwenImageTransformer2DModel(
        num_layers=2,
        num_attention_heads=4,
        attention_head_dim=32,
        joint_attention_dim=3584,
        axes_dims_rope=(8, 12, 12),
    ).to(device, dtype=torch.bfloat16)

    if enable_sp:
        model.enable_parallelism(config=ContextParallelConfig(ulysses_degree=2))

    return model


def make_batch(device):
    dtype = torch.bfloat16
    torch.manual_seed(0)
    hidden_states = torch.randn(2, 256, 64, device=device, dtype=dtype)
    encoder_hidden_states = torch.randn(2, 32, 3584, device=device, dtype=dtype)
    encoder_hidden_states_mask = torch.ones(2, 32, device=device, dtype=torch.bool)
    timestep = torch.rand(2, device=device, dtype=dtype)
    img_shapes = [[(1, 16, 16)]] * 2
    target = torch.randn(2, 256, 64, device=device, dtype=dtype)
    return (
        hidden_states,
        encoder_hidden_states,
        encoder_hidden_states_mask,
        timestep,
        img_shapes,
        target,
    )


def train(enable_sp: bool):
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    device = torch.device("cuda", rank)
    torch.cuda.set_device(device)

    model = init_model(device, enable_sp)
    model.train()

    (
        hidden_states,
        encoder_hidden_states,
        encoder_hidden_states_mask,
        timestep,
        img_shapes,
        target,
    ) = make_batch(device)

    pred = model(
        hidden_states=hidden_states,
        encoder_hidden_states=encoder_hidden_states,
        encoder_hidden_states_mask=encoder_hidden_states_mask,
        timestep=timestep,
        img_shapes=img_shapes,
        return_dict=False,
    )[0]

    loss = F.mse_loss(pred.float(), target.float())
    loss.backward()

    if rank == 0:
        print(f"loss={loss.item():.6f}")

    dist.destroy_process_group()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--enable-sp", action="store_true")
    args = parser.parse_args()
    train(enable_sp=args.enable_sp)

Logs

  • Without Ulysses SP enabled,

torchrun --nproc-per-node 2 toy_train.py

the script runs and produces the expected output:

loss=1.351188

  • With Ulysses SP enabled,

torchrun --nproc-per-node 2 toy_train.py --enable-sp

it fails with the error:

RuntimeError: Mismatch in shape: grad_output[0] has a shape of torch.Size([2, 2, 288, 32]) and output[0] has a shape of torch.Size([2, 288, 2, 32]).
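The two shapes in the error differ only by a swap of dimensions 1 and 2. Interpreting them as attention layouts (my assumption, not confirmed by the traceback): one side looks like (batch, heads, seq, head_dim) and the other like (batch, seq, heads, head_dim). A minimal shape-only check:

```python
def swap_dims_1_2(shape):
    """Swap dims 1 and 2 of a 4-D shape, e.g. (B, H, S, D) -> (B, S, H, D)."""
    b, x, y, d = shape
    return (b, y, x, d)

grad_output_shape = (2, 2, 288, 32)  # grad_output[0] from the RuntimeError
output_shape = (2, 288, 2, 32)       # output[0] from the RuntimeError

# The shapes agree after swapping dims 1 and 2.
assert swap_dims_1_2(grad_output_shape) == output_shape
```

The concrete values are also consistent with the repro config: 288 matches the joint sequence length (256 image tokens + 32 text tokens), and 2 matches num_attention_heads=4 split across ulysses_degree=2, so this looks like a heads/sequence transpose between the saved attention output and the incoming gradient.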

System Info

  • 🤗 Diffusers version: 0.38.0.dev0
  • Platform: Linux-5.15.0-1053-nvidia-x86_64-with-glibc2.35
  • Running on Google Colab?: No
  • Python version: 3.12.12
  • PyTorch version (GPU?): 2.9.1+cu128 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Huggingface_hub version: 0.36.2
  • Transformers version: 4.57.6
  • Accelerate version: 1.12.0
  • PEFT version: 0.18.1
  • Bitsandbytes version: not installed
  • Safetensors version: 0.7.0
  • xFormers version: not installed
  • Accelerator: NVIDIA H800, 81559 MiB (×4)
  • Using GPU in script?: yes
  • Using distributed or parallel set-up in script?: yes

Who can help?

@sayakpaul
