Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/diffusers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,6 +565,7 @@
"QwenImageImg2ImgPipeline",
"QwenImageInpaintPipeline",
"QwenImagePipeline",
"QwenImageLayeredPipeline",
"ReduxImageEncoder",
"SanaControlNetPipeline",
"SanaImageToVideoPipeline",
Expand Down Expand Up @@ -1272,6 +1273,7 @@
QwenImageEditPlusPipeline,
QwenImageImg2ImgPipeline,
QwenImageInpaintPipeline,
QwenImageLayeredPipeline,
QwenImagePipeline,
ReduxImageEncoder,
SanaControlNetPipeline,
Expand Down
11 changes: 7 additions & 4 deletions src/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,7 @@ def __init__(
attn_scales=[],
temperal_downsample=[True, True, False],
dropout=0.0,
input_channels=3,
non_linearity: str = "silu",
):
super().__init__()
Expand All @@ -410,7 +411,7 @@ def __init__(
scale = 1.0

# init block
self.conv_in = QwenImageCausalConv3d(3, dims[0], 3, padding=1)
self.conv_in = QwenImageCausalConv3d(input_channels, dims[0], 3, padding=1)

# downsample blocks
self.down_blocks = nn.ModuleList([])
Expand Down Expand Up @@ -570,6 +571,7 @@ def __init__(
attn_scales=[],
temperal_upsample=[False, True, True],
dropout=0.0,
input_channels=3,
non_linearity: str = "silu",
):
super().__init__()
Expand Down Expand Up @@ -621,7 +623,7 @@ def __init__(

# output blocks
self.norm_out = QwenImageRMS_norm(out_dim, images=False)
self.conv_out = QwenImageCausalConv3d(out_dim, 3, 3, padding=1)
self.conv_out = QwenImageCausalConv3d(out_dim, input_channels, 3, padding=1)

self.gradient_checkpointing = False

Expand Down Expand Up @@ -684,6 +686,7 @@ def __init__(
attn_scales: List[float] = [],
temperal_downsample: List[bool] = [False, True, True],
dropout: float = 0.0,
input_channels: int = 3,
latents_mean: List[float] = [-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508, 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921],
latents_std: List[float] = [2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743, 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160],
) -> None:
Expand All @@ -695,13 +698,13 @@ def __init__(
self.temperal_upsample = temperal_downsample[::-1]

self.encoder = QwenImageEncoder3d(
base_dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout
base_dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout, input_channels
)
self.quant_conv = QwenImageCausalConv3d(z_dim * 2, z_dim * 2, 1)
self.post_quant_conv = QwenImageCausalConv3d(z_dim, z_dim, 1)

self.decoder = QwenImageDecoder3d(
base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout
base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout, input_channels
)

self.spatial_compression_ratio = 2 ** len(self.temperal_downsample)
Expand Down
143 changes: 137 additions & 6 deletions src/diffusers/models/transformers/transformer_qwenimage.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,17 +143,26 @@ def apply_rotary_emb_qwen(


class QwenTimestepProjEmbeddings(nn.Module):
def __init__(self, embedding_dim):
def __init__(self, embedding_dim, use_additional_t_cond=False):
super().__init__()

self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000)
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
self.use_additional_t_cond = use_additional_t_cond
if use_additional_t_cond:
self.addition_t_embedding = nn.Embedding(2, embedding_dim)
Comment on lines +152 to +153
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Much better.


def forward(self, timestep, hidden_states):
def forward(self, timestep, hidden_states, addition_t_cond=None):
Comment thread
naykun marked this conversation as resolved.
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_states.dtype)) # (N, D)

conditioning = timesteps_emb
if self.use_additional_t_cond:
if addition_t_cond is None:
raise ValueError("When additional_t_cond is True, addition_t_cond must be provided.")
addition_t_emb = self.addition_t_embedding(addition_t_cond)
addition_t_emb = addition_t_emb.to(dtype=hidden_states.dtype)
conditioning = conditioning + addition_t_emb

return conditioning

Expand Down Expand Up @@ -259,6 +268,120 @@ def _compute_video_freqs(self, frame: int, height: int, width: int, idx: int = 0
return freqs.clone().contiguous()


class QwenEmbedLayer3DRope(nn.Module):
def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):
super().__init__()
self.theta = theta
self.axes_dim = axes_dim
pos_index = torch.arange(4096)
neg_index = torch.arange(4096).flip(0) * -1 - 1
self.pos_freqs = torch.cat(
[
self.rope_params(pos_index, self.axes_dim[0], self.theta),
self.rope_params(pos_index, self.axes_dim[1], self.theta),
self.rope_params(pos_index, self.axes_dim[2], self.theta),
],
dim=1,
)
self.neg_freqs = torch.cat(
[
self.rope_params(neg_index, self.axes_dim[0], self.theta),
self.rope_params(neg_index, self.axes_dim[1], self.theta),
self.rope_params(neg_index, self.axes_dim[2], self.theta),
],
dim=1,
)

self.scale_rope = scale_rope

def rope_params(self, index, dim, theta=10000):
"""
Args:
index: [0, 1, 2, 3] 1D Tensor representing the position index of the token
"""
assert dim % 2 == 0
freqs = torch.outer(index, 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim)))
freqs = torch.polar(torch.ones_like(freqs), freqs)
return freqs

def forward(self, video_fhw, txt_seq_lens, device):
"""
Args: video_fhw: [frame, height, width] a list of 3 integers representing the shape of the video Args:
txt_length: [bs] a list of 1 integers representing the length of the text
"""
if self.pos_freqs.device != device:
self.pos_freqs = self.pos_freqs.to(device)
self.neg_freqs = self.neg_freqs.to(device)

if isinstance(video_fhw, list):
video_fhw = video_fhw[0]
if not isinstance(video_fhw, list):
video_fhw = [video_fhw]

vid_freqs = []
max_vid_index = 0
layer_num = len(video_fhw) - 1
for idx, fhw in enumerate(video_fhw):
frame, height, width = fhw
if idx != layer_num:
video_freq = self._compute_video_freqs(frame, height, width, idx)
else:
### For the condition image, we set the layer index to -1
video_freq = self._compute_condition_freqs(frame, height, width)
video_freq = video_freq.to(device)
vid_freqs.append(video_freq)

if self.scale_rope:
max_vid_index = max(height // 2, width // 2, max_vid_index)
else:
max_vid_index = max(height, width, max_vid_index)

max_vid_index = max(max_vid_index, layer_num)
max_len = max(txt_seq_lens)
txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]
vid_freqs = torch.cat(vid_freqs, dim=0)

return vid_freqs, txt_freqs

@functools.lru_cache(maxsize=None)
def _compute_video_freqs(self, frame, height, width, idx=0):
seq_lens = frame * height * width
freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)

freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
if self.scale_rope:
freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0)
freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
else:
freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)

freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
return freqs.clone().contiguous()

@functools.lru_cache(maxsize=None)
def _compute_condition_freqs(self, frame, height, width):
seq_lens = frame * height * width
freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)

freqs_frame = freqs_neg[0][-1:].view(frame, 1, 1, -1).expand(frame, height, width, -1)
if self.scale_rope:
freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0)
freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
else:
freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)

freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
return freqs.clone().contiguous()


class QwenDoubleStreamAttnProcessor2_0:
"""
Attention processor for Qwen double-stream architecture, matching DoubleStreamLayerMegatron logic. This processor
Expand Down Expand Up @@ -578,14 +701,21 @@ def __init__(
guidance_embeds: bool = False, # TODO: this should probably be removed
axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
zero_cond_t: bool = False,
use_additional_t_cond: bool = False,
use_layer3d_rope: bool = False,
):
super().__init__()
self.out_channels = out_channels or in_channels
self.inner_dim = num_attention_heads * attention_head_dim

self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=list(axes_dims_rope), scale_rope=True)
if not use_layer3d_rope:
self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=list(axes_dims_rope), scale_rope=True)
else:
self.pos_embed = QwenEmbedLayer3DRope(theta=10000, axes_dim=list(axes_dims_rope), scale_rope=True)
Comment thread
sayakpaul marked this conversation as resolved.

self.time_text_embed = QwenTimestepProjEmbeddings(embedding_dim=self.inner_dim)
self.time_text_embed = QwenTimestepProjEmbeddings(
embedding_dim=self.inner_dim, use_additional_t_cond=use_additional_t_cond
)

self.txt_norm = RMSNorm(joint_attention_dim, eps=1e-6)

Expand Down Expand Up @@ -621,6 +751,7 @@ def forward(
guidance: torch.Tensor = None, # TODO: this should probably be removed
attention_kwargs: Optional[Dict[str, Any]] = None,
controlnet_block_samples=None,
additional_t_cond=None,
Comment thread
sayakpaul marked this conversation as resolved.
return_dict: bool = True,
) -> Union[torch.Tensor, Transformer2DModelOutput]:
"""
Expand Down Expand Up @@ -683,9 +814,9 @@ def forward(
guidance = guidance.to(hidden_states.dtype) * 1000

temb = (
self.time_text_embed(timestep, hidden_states)
self.time_text_embed(timestep, hidden_states, additional_t_cond)
if guidance is None
else self.time_text_embed(timestep, guidance, hidden_states)
else self.time_text_embed(timestep, guidance, hidden_states, additional_t_cond)
)

image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)
Expand Down
2 changes: 2 additions & 0 deletions src/diffusers/pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,7 @@
"QwenImageEditInpaintPipeline",
"QwenImageControlNetInpaintPipeline",
"QwenImageControlNetPipeline",
"QwenImageLayeredPipeline",
]
_import_structure["chronoedit"] = ["ChronoEditPipeline"]
try:
Expand Down Expand Up @@ -764,6 +765,7 @@
QwenImageEditPlusPipeline,
QwenImageImg2ImgPipeline,
QwenImageInpaintPipeline,
QwenImageLayeredPipeline,
QwenImagePipeline,
)
from .sana import (
Expand Down
4 changes: 3 additions & 1 deletion src/diffusers/pipelines/qwenimage/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,11 @@
_import_structure["pipeline_qwenimage_controlnet"] = ["QwenImageControlNetPipeline"]
_import_structure["pipeline_qwenimage_controlnet_inpaint"] = ["QwenImageControlNetInpaintPipeline"]
_import_structure["pipeline_qwenimage_edit"] = ["QwenImageEditPipeline"]
_import_structure["pipeline_qwenimage_edit_inpaint"] = ["QwenImageEditInpaintPipeline"]
_import_structure["pipeline_qwenimage_edit_plus"] = ["QwenImageEditPlusPipeline"]
_import_structure["pipeline_qwenimage_edit_inpaint"] = ["QwenImageEditInpaintPipeline"]
_import_structure["pipeline_qwenimage_img2img"] = ["QwenImageImg2ImgPipeline"]
_import_structure["pipeline_qwenimage_inpaint"] = ["QwenImageInpaintPipeline"]
_import_structure["pipeline_qwenimage_layered"] = ["QwenImageLayeredPipeline"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
Expand All @@ -47,6 +48,7 @@
from .pipeline_qwenimage_edit_plus import QwenImageEditPlusPipeline
from .pipeline_qwenimage_img2img import QwenImageImg2ImgPipeline
from .pipeline_qwenimage_inpaint import QwenImageInpaintPipeline
from .pipeline_qwenimage_layered import QwenImageLayeredPipeline
else:
import sys

Expand Down
Loading
Loading