Skip to content

Commit f3c6242

Browse files
committed
[qwen-image] fix pr comments
1 parent e630e6e commit f3c6242

2 files changed

Lines changed: 32 additions & 89 deletions

File tree

src/diffusers/models/transformers/transformer_qwenimage.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -143,23 +143,23 @@ def apply_rotary_emb_qwen(
143143

144144

145145
class QwenTimestepProjEmbeddings(nn.Module):
146-
def __init__(self, embedding_dim, additional_t_cond=False):
146+
def __init__(self, embedding_dim, use_additional_t_cond=False):
147147
super().__init__()
148148

149149
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000)
150150
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
151-
self.additional_t_cond = additional_t_cond
152-
if additional_t_cond:
151+
self.use_additional_t_cond = use_additional_t_cond
152+
if use_additional_t_cond:
153153
self.addition_t_embedding = nn.Embedding(2, embedding_dim)
154-
self.addition_t_embedding.weight.data.zero_()
155154

156155
def forward(self, timestep, hidden_states, addition_t_cond=None):
157156
timesteps_proj = self.time_proj(timestep)
158157
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_states.dtype)) # (N, D)
159158

160159
conditioning = timesteps_emb
161-
if self.additional_t_cond:
162-
assert addition_t_cond is not None, "When additional_t_cond is True, addition_t_cond must be provided."
160+
if self.use_additional_t_cond:
161+
if addition_t_cond is None:
162+
raise ValueError("When additional_t_cond is True, addition_t_cond must be provided.")
163163
addition_t_emb = self.addition_t_embedding(addition_t_cond)
164164
addition_t_emb = addition_t_emb.to(dtype=hidden_states.dtype)
165165
conditioning = conditioning + addition_t_emb
@@ -291,9 +291,7 @@ def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):
291291
],
292292
dim=1,
293293
)
294-
self.rope_cache = {}
295294

296-
# DO NOT USING REGISTER BUFFER HERE, IT WILL CAUSE COMPLEX NUMBERS LOSE ITS IMAGINARY PART
297295
self.scale_rope = scale_rope
298296

299297
def rope_params(self, index, dim, theta=10000):
@@ -703,7 +701,7 @@ def __init__(
703701
guidance_embeds: bool = False, # TODO: this should probably be removed
704702
axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
705703
zero_cond_t: bool = False,
706-
additional_t_cond: bool = False,
704+
use_additional_t_cond: bool = False,
707705
use_layer3d_rope: bool = False,
708706
):
709707
super().__init__()
@@ -716,7 +714,7 @@ def __init__(
716714
self.pos_embed = QwenEmbedLayer3DRope(theta=10000, axes_dim=list(axes_dims_rope), scale_rope=True)
717715

718716
self.time_text_embed = QwenTimestepProjEmbeddings(
719-
embedding_dim=self.inner_dim, additional_t_cond=additional_t_cond
717+
embedding_dim=self.inner_dim, use_additional_t_cond=use_additional_t_cond
720718
)
721719

722720
self.txt_norm = RMSNorm(joint_attention_dim, eps=1e-6)

src/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py

Lines changed: 24 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,13 @@
1818

1919
import numpy as np
2020
import torch
21-
from PIL import Image
2221
from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor
2322

2423
from ...image_processor import PipelineImageInput, VaeImageProcessor
2524
from ...loaders import QwenImageLoraLoaderMixin
2625
from ...models import AutoencoderKLQwenImage, QwenImageTransformer2DModel
2726
from ...schedulers import FlowMatchEulerDiscreteScheduler
28-
from ...utils import deprecate, is_torch_xla_available, logging, replace_example_docstring
27+
from ...utils import is_torch_xla_available, logging, replace_example_docstring
2928
from ...utils.torch_utils import randn_tensor
3029
from ..pipeline_utils import DiffusionPipeline
3130
from .pipeline_output import QwenImagePipelineOutput
@@ -152,14 +151,15 @@ def retrieve_latents(
152151
raise AttributeError("Could not access latents of provided encoder_output")
153152

154153

154+
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit_plus.calculate_dimensions
155155
def calculate_dimensions(target_area, ratio):
156156
width = math.sqrt(target_area * ratio)
157157
height = width / ratio
158158

159159
width = round(width / 32) * 32
160160
height = round(height / 32) * 32
161161

162-
return width, height, None
162+
return width, height
163163

164164

165165
class QwenImageLayeredPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
@@ -266,6 +266,7 @@ def _get_qwen_prompt_embeds(
266266

267267
return prompt_embeds, encoder_attention_mask
268268

269+
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline.encode_prompt
269270
def encode_prompt(
270271
self,
271272
prompt: Union[str, List[str]],
@@ -296,6 +297,9 @@ def encode_prompt(
296297
if prompt_embeds is None:
297298
prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device)
298299

300+
prompt_embeds = prompt_embeds[:, :max_sequence_length]
301+
prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length]
302+
299303
_, seq_len, _ = prompt_embeds.shape
300304
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
301305
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
@@ -393,6 +397,7 @@ def _unpack_latents(latents, height, width, layers, vae_scale_factor):
393397

394398
return latents
395399

400+
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline._encode_vae_image
396401
def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
397402
if isinstance(generator, list):
398403
image_latents = [
@@ -416,59 +421,6 @@ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
416421

417422
return image_latents
418423

419-
def enable_vae_slicing(self):
420-
r"""
421-
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
422-
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
423-
"""
424-
depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
425-
deprecate(
426-
"enable_vae_slicing",
427-
"0.40.0",
428-
depr_message,
429-
)
430-
self.vae.enable_slicing()
431-
432-
def disable_vae_slicing(self):
433-
r"""
434-
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
435-
computing decoding in one step.
436-
"""
437-
depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
438-
deprecate(
439-
"disable_vae_slicing",
440-
"0.40.0",
441-
depr_message,
442-
)
443-
self.vae.disable_slicing()
444-
445-
def enable_vae_tiling(self):
446-
r"""
447-
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
448-
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
449-
processing larger images.
450-
"""
451-
depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
452-
deprecate(
453-
"enable_vae_tiling",
454-
"0.40.0",
455-
depr_message,
456-
)
457-
self.vae.enable_tiling()
458-
459-
def disable_vae_tiling(self):
460-
r"""
461-
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
462-
computing decoding in one step.
463-
"""
464-
depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
465-
deprecate(
466-
"disable_vae_tiling",
467-
"0.40.0",
468-
depr_message,
469-
)
470-
self.vae.disable_tiling()
471-
472424
def prepare_latents(
473425
self,
474426
image,
@@ -560,8 +512,6 @@ def __call__(
560512
prompt: Union[str, List[str]] = None,
561513
negative_prompt: Union[str, List[str]] = None,
562514
true_cfg_scale: float = 4.0,
563-
height: Optional[int] = None,
564-
width: Optional[int] = None,
565515
layers: Optional[int] = 4,
566516
num_inference_steps: int = 50,
567517
sigmas: Optional[List[float]] = None,
@@ -607,10 +557,6 @@ def __call__(
607557
enabled by setting `true_cfg_scale > 1` and a provided `negative_prompt`. Higher guidance scale
608558
encourages the model to generate images that are closely linked to the text `prompt`, usually at the expense of
609559
lower image quality.
610-
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
611-
The height in pixels of the generated image. This is set to 1024 by default for the best results.
612-
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
613-
The width in pixels of the generated image. This is set to 1024 by default for the best results.
614560
num_inference_steps (`int`, *optional*, defaults to 50):
615561
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
616562
expense of slower inference.
@@ -663,7 +609,7 @@ def __call__(
663609
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
664610
`._callback_tensor_inputs` attribute of your pipeline class.
665611
max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
666-
resolution (`int`, *optional*, defaults to 640)
612+
resolution (`int`, *optional*, defaults to 640):
667613
Selects the resolution bucket (640 or 1024) used to determine the condition and output resolution.
668614
cfg_normalize (`bool`, *optional*, defaults to `False`):
669615
Whether to enable cfg normalization.
@@ -679,7 +625,7 @@ def __call__(
679625
"""
680626
image_size = image[0].size if isinstance(image, list) else image.size
681627
assert resolution in [640, 1024], f"resolution must be either 640 or 1024, but got {resolution}"
682-
calculated_width, calculated_height, _ = calculate_dimensions(
628+
calculated_width, calculated_height = calculate_dimensions(
683629
resolution * resolution, image_size[0] / image_size[1]
684630
)
685631
height = calculated_height
@@ -718,9 +664,6 @@ def __call__(
718664

719665
if prompt is None or prompt == "" or prompt == " ":
720666
prompt = self.get_image_caption(prompt_image, use_en_prompt=use_en_prompt, device=device)
721-
print(f"Generated prompt: {prompt}")
722-
else:
723-
print(f"User prompt: {prompt}")
724667

725668
# 3. Define call parameters
726669
if prompt is not None and isinstance(prompt, str):
@@ -917,19 +860,21 @@ def __call__(
917860
latents.device, latents.dtype
918861
)
919862
latents = latents / latents_std + latents_mean
920-
latents = torch.unbind(latents, 2)
921-
image = []
922-
for z in latents[1:]:
923-
z = z.unsqueeze(2) # b c f h w
924-
image.append(self.vae.decode(z, return_dict=False)[0])
925-
926-
image = torch.cat(image, dim=2) # b c f h w
927-
image = image.permute(0, 2, 3, 4, 1) # b f h w c
928-
image = (image * 0.5 + 0.5).clamp(0, 1).cpu().float().numpy()
929-
image = (image * 255).round().astype("uint8")
863+
864+
b, c, f, h, w = latents.shape
865+
866+
latents = latents[:, :, 1:] # remove the first frame as it is the orgin input
867+
868+
latents = latents.permute(0, 2, 1, 3, 4).view(-1, c, 1, h, w)
869+
870+
image = self.vae.decode(latents, return_dict=False)[0] # (b f) c 1 h w
871+
872+
image = image.squeeze(2)
873+
874+
image = self.image_processor.postprocess(image, output_type=output_type)
930875
images = []
931-
for layers in image:
932-
images.append([Image.fromarray(layer) for layer in layers])
876+
for bidx in range(b):
877+
images.append(image[bidx * f : (bidx + 1) * f])
933878

934879
# Offload all models
935880
self.maybe_free_model_hooks()

0 commit comments

Comments
 (0)