pixart_alpha model/pipeline review #13631

@hlky

Description

Commit tested: 0f1abc4ae8b0eb2a3b40e82a310507281144c423

Review performed against the repository review rules.

Duplicate search: searched GitHub Issues and PRs for pixart_alpha, affected class/function/file names, and failure modes. Existing duplicates/related items found for Issue 1 and Issue 5; no duplicates found for the other report items.

Coverage status: fast and slow tests exist for PixArt Alpha, PixArt Sigma, and PixArtTransformer2DModel. No missing slow tests found for the listed target files. Current coverage misses the DPM one-step path, non-patch-aligned image sizes, mixed transformer/VAE dtype decode, lazy constant exports, and QKV unfuse edge cases.
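
As one illustration of the missing fast coverage, a minimal sketch of a one-step DPM test, reusing the tiny-component configuration from the reproductions below (the free-floating test function is a placeholder, not the repository's test-class conventions; today this test would fail on the Issue 1 indexing bug, which is exactly the gap it covers):

import torch
from diffusers import AutoencoderKL, DPMSolverMultistepScheduler, PixArtAlphaPipeline, PixArtTransformer2DModel

def test_pixart_alpha_one_step_dpm():
    # Tiny components, mirroring the reproduction configs in this report.
    transformer = PixArtTransformer2DModel(
        sample_size=8, num_layers=1, patch_size=2, attention_head_dim=2, num_attention_heads=2,
        in_channels=4, cross_attention_dim=8, out_channels=8, use_additional_conditions=False,
    ).eval()
    pipe = PixArtAlphaPipeline(None, None, AutoencoderKL().eval(), transformer, DPMSolverMultistepScheduler())
    # Must complete without indexing past the scheduler's one-item tuple.
    out = pipe(
        prompt_embeds=torch.randn(1, 8, 8),
        prompt_attention_mask=torch.ones(1, 8, dtype=torch.long),
        guidance_scale=1.0, num_inference_steps=1,
        use_resolution_binning=False, output_type="latent",
    )
    assert out.images.shape[0] == 1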

Issue 1: Existing duplicate: PixArtAlpha one-step branch still breaks scheduler outputs

Affected code:

# PixArtAlphaPipeline denoising loop:
if num_inference_steps == 1:
    latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[1]
else:
    latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

# DPMSolverMultistepScheduler.step return path with return_dict=False:
if not return_dict:
    return (prev_sample,)

Problem:
PixArtAlphaPipeline indexes scheduler.step(..., return_dict=False)[1] whenever num_inference_steps == 1, but DPMSolverMultistepScheduler, the pipeline's declared scheduler type, returns a one-item tuple from step(..., return_dict=False), so one-step inference fails with an IndexError. Existing open duplicate/related issue: #8689 covers the same one-step special case for another scheduler.
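
The one-item tuple is easy to confirm in isolation (a minimal sketch; the zero tensors are placeholders for a real model output and sample):

import torch
from diffusers import DPMSolverMultistepScheduler

scheduler = DPMSolverMultistepScheduler()
scheduler.set_timesteps(1)
out = scheduler.step(
    torch.zeros(1, 4, 8, 8), scheduler.timesteps[0], torch.zeros(1, 4, 8, 8), return_dict=False
)
print(len(out))  # 1, so indexing [1] raises IndexError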

Impact:
One-step PixArt Alpha inference crashes with the pipeline's default scheduler class and with any scheduler swap whose step() returns the standard one-item tuple.

Reproduction:

import torch
from diffusers import AutoencoderKL, DPMSolverMultistepScheduler, PixArtAlphaPipeline, PixArtTransformer2DModel

transformer = PixArtTransformer2DModel(
    sample_size=8, num_layers=1, patch_size=2, attention_head_dim=2, num_attention_heads=2,
    in_channels=4, cross_attention_dim=8, out_channels=8, use_additional_conditions=False,
).eval()

pipe = PixArtAlphaPipeline(None, None, AutoencoderKL().eval(), transformer, DPMSolverMultistepScheduler())
embeds = torch.randn(1, 8, 8)
mask = torch.ones(1, 8, dtype=torch.long)

pipe(prompt_embeds=embeds, prompt_attention_mask=mask, guidance_scale=1.0,
     num_inference_steps=1, use_resolution_binning=False, output_type="latent")

Relevant precedent:

# compute previous image: x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

Suggested fix:

latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

Issue 2: Height/width validation allows sizes that cannot be patchified

Affected code:

# pipeline_pixart_alpha.py:
if height % 8 != 0 or width % 8 != 0:
    raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

# pipeline_pixart_sigma.py:
if height % 8 != 0 or width % 8 != 0:
    raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

Problem:
Both pipelines only require height and width to be divisible by 8, but the latent grid is height // vae_scale_factor by width // vae_scale_factor, and the transformer then patchifies that grid with transformer.config.patch_size. Unless resolution binning rewrites the size first, dimensions must therefore be divisible by vae_scale_factor * patch_size.
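
Concretely, with typical PixArt Alpha geometry (vae_scale_factor = 8, patch_size = 2, so the true divisor is 16):

vae_scale_factor, patch_size = 8, 2

height = 24                          # passes the current `% 8` check
latent = height // vae_scale_factor  # 3 latent pixels per side
print(latent % patch_size)           # 1: a 3x3 latent grid cannot be split into 2x2 patches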

Impact:
With use_resolution_binning=False, sizes that the documented check accepts are then crashed on inside the denoising loop when the transformer patchifies the latents.

Reproduction:

import torch
from diffusers import AutoencoderKL, DDIMScheduler, PixArtAlphaPipeline, PixArtTransformer2DModel

transformer = PixArtTransformer2DModel(
    sample_size=8, num_layers=1, patch_size=2, attention_head_dim=2, num_attention_heads=2,
    in_channels=4, cross_attention_dim=8, out_channels=8, use_additional_conditions=False,
).eval()
vae = AutoencoderKL(
    block_out_channels=(8, 8, 8, 8),
    down_block_types=("DownEncoderBlock2D",) * 4,
    up_block_types=("UpDecoderBlock2D",) * 4,
    norm_num_groups=4,
).eval()

pipe = PixArtAlphaPipeline(None, None, vae, transformer, DDIMScheduler())
embeds = torch.randn(1, 8, 8)
mask = torch.ones(1, 8, dtype=torch.long)

pipe(prompt_embeds=embeds, prompt_attention_mask=mask, guidance_scale=1.0,
     num_inference_steps=1, height=24, width=24,
     use_resolution_binning=False, output_type="latent")

Relevant precedent:
Sana derives its latent geometry from vae_scale_factor in prepare_latents and casts provided latents to the requested dtype:

def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
    if latents is not None:
        return latents.to(device=device, dtype=dtype)

    shape = (
        batch_size,
        num_channels_latents,
        int(height) // self.vae_scale_factor,
        int(width) // self.vae_scale_factor,
    )

    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
    return latents

Suggested fix:

divisor = self.vae_scale_factor * self.transformer.config.patch_size
if height % divisor != 0 or width % divisor != 0:
    raise ValueError(f"`height` and `width` have to be divisible by {divisor} but are {height} and {width}.")

Issue 3: PixArtAlpha decode does not cast latents to VAE dtype

Affected code:

if not output_type == "latent":
    image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
    if use_resolution_binning:

Problem:
The Alpha pipeline passes latents / self.vae.config.scaling_factor to the VAE without a dtype cast. If the transformer runs in float16 while the VAE stays in float32, the float16 latents hit float32 convolution weights and decode fails with an input/bias dtype mismatch. The Sigma pipeline already casts latents to self.vae.dtype first.
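
The failure mode reduces to a plain dtype mismatch in torch (a minimal sketch, independent of diffusers):

import torch

conv = torch.nn.Conv2d(4, 4, kernel_size=1)       # float32 weights, like an un-cast VAE
x = torch.randn(1, 4, 8, 8, dtype=torch.float16)  # float16 latents from the transformer

try:
    conv(x)
except RuntimeError as e:
    print(e)  # input and weight dtypes must match

conv(x.to(conv.weight.dtype))  # the cast in the suggested fix resolves it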

Impact:
Mixed precision component loading, partial offload workflows, or manually supplied components can fail at the final decode step.

Reproduction:

import torch
from diffusers import AutoencoderKL, DDIMScheduler, PixArtAlphaPipeline, PixArtTransformer2DModel

transformer = PixArtTransformer2DModel(
    sample_size=8, num_layers=1, patch_size=2, attention_head_dim=2, num_attention_heads=2,
    in_channels=4, cross_attention_dim=8, out_channels=8, use_additional_conditions=False,
).eval().to(dtype=torch.float16)

pipe = PixArtAlphaPipeline(None, None, AutoencoderKL().eval(), transformer, DDIMScheduler())
embeds = torch.randn(1, 8, 8, dtype=torch.float16)
mask = torch.ones(1, 8, dtype=torch.long)

# Two steps, so the run reaches decode without hitting the Issue 1 one-step branch.
pipe(prompt_embeds=embeds, prompt_attention_mask=mask, guidance_scale=1.0,
     num_inference_steps=2, use_resolution_binning=False, output_type="np")

Relevant precedent:

if not output_type == "latent":
    image = self.vae.decode(latents.to(self.vae.dtype) / self.vae.config.scaling_factor, return_dict=False)[0]

Suggested fix:

image = self.vae.decode(latents.to(self.vae.dtype) / self.vae.config.scaling_factor, return_dict=False)[0]

Issue 4: PixArt aspect-ratio constants are imported only in the non-lazy branch

Affected code:

_import_structure["pipeline_pixart_alpha"] = ["PixArtAlphaPipeline"]
_import_structure["pipeline_pixart_sigma"] = ["PixArtSigmaPipeline"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    try:
        if not (is_transformers_available() and is_torch_available()):
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from ...utils.dummy_torch_and_transformers_objects import *
    else:
        from .pipeline_pixart_alpha import (
            ASPECT_RATIO_256_BIN,
            ASPECT_RATIO_512_BIN,
            ASPECT_RATIO_1024_BIN,
            PixArtAlphaPipeline,
        )
        from .pipeline_pixart_sigma import ASPECT_RATIO_2048_BIN, PixArtSigmaPipeline

Problem:
The TYPE_CHECKING/slow-import branch imports ASPECT_RATIO_256_BIN, ASPECT_RATIO_512_BIN, ASPECT_RATIO_1024_BIN, and ASPECT_RATIO_2048_BIN, but _import_structure exposes only the two pipeline classes. Normal lazy imports fail for names that the non-lazy branch advertises.

Impact:
Public subpackage imports are inconsistent across lazy and slow import modes.

Reproduction:

from diffusers.pipelines.pixart_alpha import ASPECT_RATIO_1024_BIN
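
Under the default lazy imports this fails, while forcing the eager branch resolves the same name, confirming the branch mismatch (a sketch assuming DIFFUSERS_SLOW_IMPORT is read on first import, as in current diffusers; run in a fresh interpreter):

import os
os.environ["DIFFUSERS_SLOW_IMPORT"] = "1"  # must be set before diffusers is first imported

from diffusers.pipelines.pixart_alpha import ASPECT_RATIO_1024_BIN  # resolves via the eager branch
print(len(ASPECT_RATIO_1024_BIN))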

Relevant precedent:
The same file already imports these constants in the eager branch.

Suggested fix:

_import_structure["pipeline_pixart_alpha"] = [
    "ASPECT_RATIO_256_BIN",
    "ASPECT_RATIO_512_BIN",
    "ASPECT_RATIO_1024_BIN",
    "PixArtAlphaPipeline",
]
_import_structure["pipeline_pixart_sigma"] = ["ASPECT_RATIO_2048_BIN", "PixArtSigmaPipeline"]

Issue 5: Existing related issue: PixArtTransformer QKV unfuse state is not safe

Affected code:

# fuse_qkv_projections:
self.original_attn_processors = None

for _, attn_processor in self.attn_processors.items():
    if "Added" in str(attn_processor.__class__.__name__):
        raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")

self.original_attn_processors = self.attn_processors

for module in self.modules():
    if isinstance(module, Attention):
        module.fuse_projections(fuse=True)

self.set_attn_processor(FusedAttnProcessor2_0())

# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
def unfuse_qkv_projections(self):
    """Disables the fused QKV projection if enabled.

    This API is 🧪 experimental.
    """
    if self.original_attn_processors is not None:
        self.set_attn_processor(self.original_attn_processors)

Problem:
original_attn_processors is only created inside fuse_qkv_projections(), so calling unfuse_qkv_projections() first raises AttributeError. Calling fuse twice overwrites the saved original processors with the already-fused ones, so a later unfuse restores fused processors and the model stays fused. Related existing issue: #13592 reports the same copied QKV state bug for UNet.

Impact:
The public attention optimization API is not idempotent and can leave PixArt fused after an enable-twice-disable flow.

Reproduction:

from diffusers import PixArtTransformer2DModel

model = PixArtTransformer2DModel(
    sample_size=8, num_layers=1, attention_head_dim=2, num_attention_heads=2,
    cross_attention_dim=8, num_embeds_ada_norm=8, use_additional_conditions=False,
)

try:
    model.unfuse_qkv_projections()
except AttributeError as e:
    print(e)  # 'PixArtTransformer2DModel' object has no attribute 'original_attn_processors'

model.fuse_qkv_projections()
model.fuse_qkv_projections()  # overwrites the saved originals with fused processors
model.unfuse_qkv_projections()
print({p.__class__.__name__ for p in model.attn_processors.values()})  # still fused

Relevant precedent:
The method doc says it disables fused projection “if enabled”.

Suggested fix:

# in __init__
self.original_attn_processors = None

# in fuse_qkv_projections
if self.original_attn_processors is None:
    self.original_attn_processors = self.attn_processors

# in unfuse_qkv_projections
if self.original_attn_processors is not None:
    self.set_attn_processor(self.original_attn_processors)
    self.original_attn_processors = None
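
A self-contained sketch of why these guards fix both failure modes, using a toy class rather than the real model (the class and its string-valued processors are illustrative only):

class ToyModel:
    def __init__(self):
        self.original_attn_processors = None           # initialized up front, so unfuse-first is safe
        self.processors = {"attn": "AttnProcessor2_0"}

    def fuse_qkv_projections(self):
        if self.original_attn_processors is None:      # save only on the first fuse
            self.original_attn_processors = dict(self.processors)
        self.processors = {"attn": "FusedAttnProcessor2_0"}

    def unfuse_qkv_projections(self):
        if self.original_attn_processors is not None:
            self.processors = self.original_attn_processors
            self.original_attn_processors = None       # allow a clean fuse/unfuse cycle later

m = ToyModel()
m.unfuse_qkv_projections()  # no-op instead of AttributeError
m.fuse_qkv_projections()
m.fuse_qkv_projections()    # second fuse no longer clobbers the saved originals
m.unfuse_qkv_projections()
print(m.processors)         # {'attn': 'AttnProcessor2_0'}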
