diff --git a/scripts/convert_joyimage_edit_to_diffusers.py b/scripts/convert_joyimage_edit_to_diffusers.py new file mode 100644 index 000000000000..3ad23de8f462 --- /dev/null +++ b/scripts/convert_joyimage_edit_to_diffusers.py @@ -0,0 +1,366 @@ +import argparse +from typing import Any, Dict, Tuple + +import torch +from accelerate import init_empty_weights +from transformers import AutoProcessor, AutoTokenizer, Qwen3VLForConditionalGeneration + +from diffusers import ( + AutoencoderKLWan, + JoyImageEditPipeline, + JoyImageEditTransformer3DModel, +) +from diffusers.schedulers.scheduling_flow_match_euler_discrete import ( + FlowMatchEulerDiscreteScheduler, +) + + +# This code is modified from convert_wan_to_diffusers.py to support input ckpt path +def convert_vae(vae_ckpt_path): + old_state_dict = torch.load(vae_ckpt_path, weights_only=True) + new_state_dict = {} + + # Create mappings for specific components + middle_key_mapping = { + # Encoder middle block + "encoder.middle.0.residual.0.gamma": "encoder.mid_block.resnets.0.norm1.gamma", + "encoder.middle.0.residual.2.bias": "encoder.mid_block.resnets.0.conv1.bias", + "encoder.middle.0.residual.2.weight": "encoder.mid_block.resnets.0.conv1.weight", + "encoder.middle.0.residual.3.gamma": "encoder.mid_block.resnets.0.norm2.gamma", + "encoder.middle.0.residual.6.bias": "encoder.mid_block.resnets.0.conv2.bias", + "encoder.middle.0.residual.6.weight": "encoder.mid_block.resnets.0.conv2.weight", + "encoder.middle.2.residual.0.gamma": "encoder.mid_block.resnets.1.norm1.gamma", + "encoder.middle.2.residual.2.bias": "encoder.mid_block.resnets.1.conv1.bias", + "encoder.middle.2.residual.2.weight": "encoder.mid_block.resnets.1.conv1.weight", + "encoder.middle.2.residual.3.gamma": "encoder.mid_block.resnets.1.norm2.gamma", + "encoder.middle.2.residual.6.bias": "encoder.mid_block.resnets.1.conv2.bias", + "encoder.middle.2.residual.6.weight": "encoder.mid_block.resnets.1.conv2.weight", + # Decoder middle block + "decoder.middle.0.residual.0.gamma": "decoder.mid_block.resnets.0.norm1.gamma", + "decoder.middle.0.residual.2.bias": "decoder.mid_block.resnets.0.conv1.bias", + "decoder.middle.0.residual.2.weight": "decoder.mid_block.resnets.0.conv1.weight", + "decoder.middle.0.residual.3.gamma": "decoder.mid_block.resnets.0.norm2.gamma", + "decoder.middle.0.residual.6.bias": "decoder.mid_block.resnets.0.conv2.bias", + "decoder.middle.0.residual.6.weight": "decoder.mid_block.resnets.0.conv2.weight", + "decoder.middle.2.residual.0.gamma": "decoder.mid_block.resnets.1.norm1.gamma", + "decoder.middle.2.residual.2.bias": "decoder.mid_block.resnets.1.conv1.bias", + "decoder.middle.2.residual.2.weight": "decoder.mid_block.resnets.1.conv1.weight", + "decoder.middle.2.residual.3.gamma": "decoder.mid_block.resnets.1.norm2.gamma", + "decoder.middle.2.residual.6.bias": "decoder.mid_block.resnets.1.conv2.bias", + "decoder.middle.2.residual.6.weight": "decoder.mid_block.resnets.1.conv2.weight", + } + + # Create a mapping for attention blocks + attention_mapping = { + # Encoder middle attention + "encoder.middle.1.norm.gamma": "encoder.mid_block.attentions.0.norm.gamma", + "encoder.middle.1.to_qkv.weight": "encoder.mid_block.attentions.0.to_qkv.weight", + "encoder.middle.1.to_qkv.bias": "encoder.mid_block.attentions.0.to_qkv.bias", + "encoder.middle.1.proj.weight": "encoder.mid_block.attentions.0.proj.weight", + "encoder.middle.1.proj.bias": "encoder.mid_block.attentions.0.proj.bias", + # Decoder middle attention + "decoder.middle.1.norm.gamma": 
"decoder.mid_block.attentions.0.norm.gamma", + "decoder.middle.1.to_qkv.weight": "decoder.mid_block.attentions.0.to_qkv.weight", + "decoder.middle.1.to_qkv.bias": "decoder.mid_block.attentions.0.to_qkv.bias", + "decoder.middle.1.proj.weight": "decoder.mid_block.attentions.0.proj.weight", + "decoder.middle.1.proj.bias": "decoder.mid_block.attentions.0.proj.bias", + } + + # Create a mapping for the head components + head_mapping = { + # Encoder head + "encoder.head.0.gamma": "encoder.norm_out.gamma", + "encoder.head.2.bias": "encoder.conv_out.bias", + "encoder.head.2.weight": "encoder.conv_out.weight", + # Decoder head + "decoder.head.0.gamma": "decoder.norm_out.gamma", + "decoder.head.2.bias": "decoder.conv_out.bias", + "decoder.head.2.weight": "decoder.conv_out.weight", + } + + # Create a mapping for the quant components + quant_mapping = { + "conv1.weight": "quant_conv.weight", + "conv1.bias": "quant_conv.bias", + "conv2.weight": "post_quant_conv.weight", + "conv2.bias": "post_quant_conv.bias", + } + + # Process each key in the state dict + for key, value in old_state_dict.items(): + # Handle middle block keys using the mapping + if key in middle_key_mapping: + new_key = middle_key_mapping[key] + new_state_dict[new_key] = value + # Handle attention blocks using the mapping + elif key in attention_mapping: + new_key = attention_mapping[key] + new_state_dict[new_key] = value + # Handle head keys using the mapping + elif key in head_mapping: + new_key = head_mapping[key] + new_state_dict[new_key] = value + # Handle quant keys using the mapping + elif key in quant_mapping: + new_key = quant_mapping[key] + new_state_dict[new_key] = value + # Handle encoder conv1 + elif key == "encoder.conv1.weight": + new_state_dict["encoder.conv_in.weight"] = value + elif key == "encoder.conv1.bias": + new_state_dict["encoder.conv_in.bias"] = value + # Handle decoder conv1 + elif key == "decoder.conv1.weight": + new_state_dict["decoder.conv_in.weight"] = value + elif key == "decoder.conv1.bias": + new_state_dict["decoder.conv_in.bias"] = value + # Handle encoder downsamples + elif key.startswith("encoder.downsamples."): + # Convert to down_blocks + new_key = key.replace("encoder.downsamples.", "encoder.down_blocks.") + + # Convert residual block naming but keep the original structure + if ".residual.0.gamma" in new_key: + new_key = new_key.replace(".residual.0.gamma", ".norm1.gamma") + elif ".residual.2.bias" in new_key: + new_key = new_key.replace(".residual.2.bias", ".conv1.bias") + elif ".residual.2.weight" in new_key: + new_key = new_key.replace(".residual.2.weight", ".conv1.weight") + elif ".residual.3.gamma" in new_key: + new_key = new_key.replace(".residual.3.gamma", ".norm2.gamma") + elif ".residual.6.bias" in new_key: + new_key = new_key.replace(".residual.6.bias", ".conv2.bias") + elif ".residual.6.weight" in new_key: + new_key = new_key.replace(".residual.6.weight", ".conv2.weight") + elif ".shortcut.bias" in new_key: + new_key = new_key.replace(".shortcut.bias", ".conv_shortcut.bias") + elif ".shortcut.weight" in new_key: + new_key = new_key.replace(".shortcut.weight", ".conv_shortcut.weight") + + new_state_dict[new_key] = value + + # Handle decoder upsamples + elif key.startswith("decoder.upsamples."): + # Convert to up_blocks + parts = key.split(".") + block_idx = int(parts[2]) + + # Group residual blocks + if "residual" in key: + if block_idx in [0, 1, 2]: + new_block_idx = 0 + resnet_idx = block_idx + elif block_idx in [4, 5, 6]: + new_block_idx = 1 + resnet_idx = block_idx - 4 + elif 
block_idx in [8, 9, 10]: + new_block_idx = 2 + resnet_idx = block_idx - 8 + elif block_idx in [12, 13, 14]: + new_block_idx = 3 + resnet_idx = block_idx - 12 + else: + # Keep as is for other blocks + new_state_dict[key] = value + continue + + # Convert residual block naming + if ".residual.0.gamma" in key: + new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.norm1.gamma" + elif ".residual.2.bias" in key: + new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv1.bias" + elif ".residual.2.weight" in key: + new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv1.weight" + elif ".residual.3.gamma" in key: + new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.norm2.gamma" + elif ".residual.6.bias" in key: + new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv2.bias" + elif ".residual.6.weight" in key: + new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv2.weight" + else: + new_key = key + + new_state_dict[new_key] = value + + # Handle shortcut connections + elif ".shortcut." in key: + if block_idx == 4: + new_key = key.replace(".shortcut.", ".resnets.0.conv_shortcut.") + new_key = new_key.replace("decoder.upsamples.4", "decoder.up_blocks.1") + else: + new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.") + new_key = new_key.replace(".shortcut.", ".conv_shortcut.") + + new_state_dict[new_key] = value + + # Handle upsamplers + elif ".resample." in key or ".time_conv." in key: + if block_idx == 3: + new_key = key.replace( + f"decoder.upsamples.{block_idx}", + "decoder.up_blocks.0.upsamplers.0", + ) + elif block_idx == 7: + new_key = key.replace( + f"decoder.upsamples.{block_idx}", + "decoder.up_blocks.1.upsamplers.0", + ) + elif block_idx == 11: + new_key = key.replace( + f"decoder.upsamples.{block_idx}", + "decoder.up_blocks.2.upsamplers.0", + ) + else: + new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.") + + new_state_dict[new_key] = value + else: + new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.") + new_state_dict[new_key] = value + else: + # Keep other keys unchanged + new_state_dict[key] = value + + with init_empty_weights(): + vae = AutoencoderKLWan() + vae.load_state_dict(new_state_dict, strict=True, assign=True) + return vae + + +def get_transformer_config() -> Tuple[Dict[str, Any], ...]: + config = { + "diffusers_config": { + "hidden_size": 4096, + "in_channels": 16, + "num_attention_heads": 32, + "num_layers": 40, + "out_channels": 16, + "patch_size": [1, 2, 2], + "rope_dim_list": [16, 56, 56], + "text_dim": 4096, + "rope_type": "rope", + "theta": 10000, + }, + } + return config + + +def convert_transformer(ckpt_path: str): + checkpoint = torch.load(ckpt_path, weights_only=True) + if "model" in checkpoint: + original_state_dict = checkpoint["model"] + else: + original_state_dict = checkpoint + + # Attention weights moved from block to block.attn submodule + attn_suffixes = ( + "img_attn_qkv.", + "img_attn_q_norm.", + "img_attn_k_norm.", + "img_attn_proj.", + "txt_attn_qkv.", + "txt_attn_q_norm.", + "txt_attn_k_norm.", + "txt_attn_proj.", + ) + remapped = {} + for key, value in original_state_dict.items(): + new_key = key + if key.startswith("double_blocks."): + for suffix in attn_suffixes: + # double_blocks.0.img_attn_qkv.weight -> double_blocks.0.attn.img_attn_qkv.weight + if "." + suffix in key and ".attn." + suffix not in key: + new_key = key.replace("." + suffix, ".attn." 
+ suffix) + break + remapped[new_key] = value + + config = get_transformer_config() + with init_empty_weights(): + transformer = JoyImageEditTransformer3DModel(**config["diffusers_config"]) + transformer.load_state_dict(remapped, strict=True, assign=True) + return transformer + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--transformer_ckpt_path", + type=str, + default=None, + help="Path to original transformer checkpoint", + ) + parser.add_argument( + "--vae_ckpt_path", + type=str, + default=None, + help="Path to original VAE checkpoint", + ) + parser.add_argument( + "--text_encoder_path", + type=str, + default=None, + help="Path to original llama checkpoint", + ) + parser.add_argument( + "--tokenizer_path", + type=str, + default=None, + help="Path to original llama tokenizer", + ) + parser.add_argument("--save_pipeline", action="store_true") + parser.add_argument( + "--output_path", + type=str, + required=True, + help="Path where converted model should be saved", + ) + parser.add_argument("--dtype", default="bf16", help="Torch dtype to save the transformer in.") + parser.add_argument("--flow_shift", type=float, default=7.0) + return parser.parse_args() + + +DTYPE_MAPPING = { + "fp32": torch.float32, + "fp16": torch.float16, + "bf16": torch.bfloat16, +} +if __name__ == "__main__": + args = get_args() + transformer = None + vae = None + dtype = DTYPE_MAPPING[args.dtype] + + if args.save_pipeline: + assert args.transformer_ckpt_path is not None and args.vae_ckpt_path is not None + assert args.text_encoder_path is not None + # assert args.tokenizer_path is not None + if args.transformer_ckpt_path is not None: + transformer = convert_transformer(args.transformer_ckpt_path) + transformer = transformer.to(dtype=dtype) + if not args.save_pipeline: + transformer.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB") + if args.vae_ckpt_path is not None: + vae = convert_vae(args.vae_ckpt_path) + vae = vae.to(dtype=dtype) + if not args.save_pipeline: + vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB") + if args.save_pipeline: + processor = AutoProcessor.from_pretrained(args.text_encoder_path) + text_encoder = Qwen3VLForConditionalGeneration.from_pretrained( + args.text_encoder_path, torch_dtype=torch.bfloat16 + ).to("cuda") + tokenizer = AutoTokenizer.from_pretrained(args.text_encoder_path) + flow_shift = 1.5 + scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=flow_shift) + transformer = transformer.to("cuda") + vae = vae.to("cuda") + pipe = JoyImageEditPipeline( + processor=processor, + transformer=transformer, + text_encoder=text_encoder, + tokenizer=tokenizer, + vae=vae, + scheduler=scheduler, + ).to("cuda") + pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB") + processor.save_pretrained(f"{args.output_path}/processor") diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index c9caea09d8a4..fd9781088b6e 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -252,6 +252,7 @@ "HunyuanVideoFramepackTransformer3DModel", "HunyuanVideoTransformer3DModel", "I2VGenXLUNet", + "JoyImageEditTransformer3DModel", "Kandinsky3UNet", "Kandinsky5Transformer3DModel", "LatteTransformer3DModel", @@ -580,6 +581,8 @@ "IFPipeline", "IFSuperResolutionPipeline", "ImageTextPipelineOutput", + "JoyImageEditPipeline", + "JoyImageEditPipelineOutput", "Kandinsky3Img2ImgPipeline", "Kandinsky3Pipeline", "Kandinsky5I2IPipeline", @@ -1069,6 
+1072,7 @@ HunyuanVideoFramepackTransformer3DModel, HunyuanVideoTransformer3DModel, I2VGenXLUNet, + JoyImageEditTransformer3DModel, Kandinsky3UNet, Kandinsky5Transformer3DModel, LatteTransformer3DModel, @@ -1372,6 +1376,8 @@ IFPipeline, IFSuperResolutionPipeline, ImageTextPipelineOutput, + JoyImageEditPipeline, + JoyImageEditPipelineOutput, Kandinsky3Img2ImgPipeline, Kandinsky3Pipeline, Kandinsky5I2IPipeline, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index dc772fcc6d0c..65a4f744a8b9 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -113,6 +113,9 @@ _import_structure["transformers.transformer_hunyuan_video15"] = ["HunyuanVideo15Transformer3DModel"] _import_structure["transformers.transformer_hunyuan_video_framepack"] = ["HunyuanVideoFramepackTransformer3DModel"] _import_structure["transformers.transformer_hunyuanimage"] = ["HunyuanImageTransformer2DModel"] + _import_structure["transformers.transformer_joyimage"] = [ + "JoyImageEditTransformer3DModel", + ] _import_structure["transformers.transformer_kandinsky"] = ["Kandinsky5Transformer3DModel"] _import_structure["transformers.transformer_longcat_audio_dit"] = ["LongCatAudioDiTTransformer"] _import_structure["transformers.transformer_longcat_image"] = ["LongCatImageTransformer2DModel"] @@ -236,6 +239,7 @@ HunyuanVideo15Transformer3DModel, HunyuanVideoFramepackTransformer3DModel, HunyuanVideoTransformer3DModel, + JoyImageEditTransformer3DModel, Kandinsky5Transformer3DModel, LatteTransformer3DModel, LongCatAudioDiTTransformer, diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index bbd7ecfa911b..5c64b5fc99fa 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -36,6 +36,7 @@ from .transformer_hunyuan_video15 import HunyuanVideo15Transformer3DModel from .transformer_hunyuan_video_framepack import HunyuanVideoFramepackTransformer3DModel from .transformer_hunyuanimage import HunyuanImageTransformer2DModel + from .transformer_joyimage import JoyImageEditTransformer3DModel from .transformer_kandinsky import Kandinsky5Transformer3DModel from .transformer_longcat_audio_dit import LongCatAudioDiTTransformer from .transformer_longcat_image import LongCatImageTransformer2DModel diff --git a/src/diffusers/models/transformers/transformer_joyimage.py b/src/diffusers/models/transformers/transformer_joyimage.py new file mode 100644 index 000000000000..3a8e496d1218 --- /dev/null +++ b/src/diffusers/models/transformers/transformer_joyimage.py @@ -0,0 +1,589 @@ +# Copyright 2025 The JoyImage Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import inspect +import math +from typing import Tuple + +import torch +import torch.nn as nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import logging +from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward +from ..attention_dispatch import dispatch_attention_fn +from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin +from ..normalization import FP32LayerNorm + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# --------------------------------------------------------------------------- +# Rotary position embedding utilities +# --------------------------------------------------------------------------- + + +def _apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: Tuple[torch.Tensor, torch.Tensor], +) -> Tuple[torch.Tensor, torch.Tensor]: + ndim = xq.ndim + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(xq.shape)] + cos = freqs_cis[0].view(*shape).to(xq.device) + sin = freqs_cis[1].view(*shape).to(xq.device) + + def _rotate_half(x): + x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1) + return torch.stack([-x_imag, x_real], dim=-1).flatten(3) + + xq_out = (xq.float() * cos + _rotate_half(xq) * sin).type_as(xq) + xk_out = (xk.float() * cos + _rotate_half(xk) * sin).type_as(xk) + return xq_out, xk_out + + +# --------------------------------------------------------------------------- +# Modulation +# --------------------------------------------------------------------------- + + +class JoyImageModulate(nn.Module): + """Wan-style learnable modulation table. + + Produces `factor` modulation vectors by adding the conditioning signal to a learnable parameter table. + """ + + def __init__(self, hidden_size: int, factor: int, dtype=None, device=None): + super().__init__() + self.factor = factor + self.modulate_table = nn.Parameter( + torch.zeros(1, factor, hidden_size, dtype=dtype, device=device) / hidden_size**0.5, + requires_grad=True, + ) + + def forward(self, x: torch.Tensor) -> list[torch.Tensor]: + if x.ndim != 3: + x = x.unsqueeze(1) + return [o.squeeze(1) for o in (self.modulate_table + x).chunk(self.factor, dim=1)] + + +# --------------------------------------------------------------------------- +# Attention processor +# --------------------------------------------------------------------------- + + +class JoyImageAttnProcessor: + """Attention processor for JoyImage double-stream joint attention. + + Implements the joint attention computation where text and image streams are processed together. The + :class:`JoyImageAttention` module stores fused QKV projections (``img_attn_qkv`` / ``txt_attn_qkv``). 
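+
+    Queries, keys, and values from both streams are concatenated along the sequence dimension, attended
+    over in a single pass, and the result is split back into image and text segments before the
+    per-stream output projections.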
+ """ + + _attention_backend = None + _parallel_config = None + + def __init__(self): + pass + + def __call__( + self, + attn: "JoyImageAttention", + hidden_states: torch.Tensor, # image stream (B, S_img, D) + encoder_hidden_states: torch.Tensor = None, # text stream (B, S_txt, D) + image_rotary_emb: Tuple[torch.Tensor, torch.Tensor] | None = None, + **kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if encoder_hidden_states is None: + raise ValueError("JoyImageAttnProcessor requires encoder_hidden_states (text stream)") + + heads = attn.heads + + # image stream: fused QKV -> split + img_qkv = attn.img_attn_qkv(hidden_states) + img_query, img_key, img_value = img_qkv.chunk(3, dim=-1) + + # text stream: fused QKV -> split + txt_qkv = attn.txt_attn_qkv(encoder_hidden_states) + txt_query, txt_key, txt_value = txt_qkv.chunk(3, dim=-1) + + # reshape to multi-head: (B, S, H, D) + img_query = img_query.unflatten(-1, (heads, -1)) + img_key = img_key.unflatten(-1, (heads, -1)) + img_value = img_value.unflatten(-1, (heads, -1)) + + txt_query = txt_query.unflatten(-1, (heads, -1)) + txt_key = txt_key.unflatten(-1, (heads, -1)) + txt_value = txt_value.unflatten(-1, (heads, -1)) + + # QK norm + img_query = attn.img_attn_q_norm(img_query) + img_key = attn.img_attn_k_norm(img_key) + txt_query = attn.txt_attn_q_norm(txt_query) + txt_key = attn.txt_attn_k_norm(txt_key) + + # RoPE (custom implementation) + if image_rotary_emb is not None: + vis_freqs, txt_freqs = image_rotary_emb + if vis_freqs is not None: + img_query, img_key = _apply_rotary_emb(img_query, img_key, vis_freqs) + if txt_freqs is not None: + txt_query, txt_key = _apply_rotary_emb(txt_query, txt_key, txt_freqs) + + # concatenate for joint attention: [img, txt] + joint_query = torch.cat([img_query, txt_query], dim=1) + joint_key = torch.cat([img_key, txt_key], dim=1) + joint_value = torch.cat([img_value, txt_value], dim=1) + + joint_hidden_states = dispatch_attention_fn( + joint_query, + joint_key, + joint_value, + attn_mask=None, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, + parallel_config=self._parallel_config, + ) + + joint_hidden_states = joint_hidden_states.flatten(2, 3) + joint_hidden_states = joint_hidden_states.to(joint_query.dtype) + + # split back + img_attn_output = joint_hidden_states[:, : hidden_states.shape[1], :] + txt_attn_output = joint_hidden_states[:, hidden_states.shape[1] :, :] + + # output projections + img_attn_output = attn.img_attn_proj(img_attn_output) + txt_attn_output = attn.txt_attn_proj(txt_attn_output) + + return img_attn_output, txt_attn_output + + +# --------------------------------------------------------------------------- +# Attention module +# --------------------------------------------------------------------------- + + +class JoyImageAttention(nn.Module, AttentionModuleMixin): + """Joint attention module for JoyImage double-stream blocks. + + Wraps the fused QKV projections, QK norms, and output projections for both image and text streams. Delegates the + actual attention computation to a pluggable :class:`JoyImageAttnProcessor`. 
+ """ + + _default_processor_cls = JoyImageAttnProcessor + _available_processors = [JoyImageAttnProcessor] + _supports_qkv_fusion = False + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + eps: float = 1e-6, + processor=None, + ): + super().__init__() + + self.heads = num_attention_heads + self.head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + + self.img_attn_qkv = nn.Linear(dim, inner_dim * 3, bias=True) + self.img_attn_q_norm = nn.RMSNorm(attention_head_dim, eps=eps) + self.img_attn_k_norm = nn.RMSNorm(attention_head_dim, eps=eps) + self.img_attn_proj = nn.Linear(inner_dim, dim, bias=True) + + self.txt_attn_qkv = nn.Linear(dim, inner_dim * 3, bias=True) + self.txt_attn_q_norm = nn.RMSNorm(attention_head_dim, eps=eps) + self.txt_attn_k_norm = nn.RMSNorm(attention_head_dim, eps=eps) + self.txt_attn_proj = nn.Linear(inner_dim, dim, bias=True) + + if processor is None: + processor = self._default_processor_cls() + self.set_processor(processor) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor | None = None, + image_rotary_emb: Tuple[torch.Tensor, torch.Tensor] | None = None, + **kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor]: + attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) + unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters] + if len(unused_kwargs) > 0: + logger.warning( + f"joint_attention_kwargs {unused_kwargs} are not expected by " + f"{self.processor.__class__.__name__} and will be ignored." + ) + kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters} + return self.processor(self, hidden_states, encoder_hidden_states, image_rotary_emb, **kwargs) + + +# --------------------------------------------------------------------------- +# Transformer block +# --------------------------------------------------------------------------- + + +class JoyImageTransformerBlock(nn.Module): + """Double-stream transformer block for JoyImage. + + Each block processes an image stream and a text stream jointly through shared attention, following the SD3 / Flux + double-stream pattern with WAN-style modulation. 
+ """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + mlp_width_ratio: float = 4.0, + eps: float = 1e-6, + ): + super().__init__() + + self.dim = dim + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + mlp_hidden_dim = int(dim * mlp_width_ratio) + + # image stream + self.img_mod = JoyImageModulate(dim, factor=6) + self.img_norm1 = FP32LayerNorm(dim, elementwise_affine=False, eps=eps) + self.img_norm2 = FP32LayerNorm(dim, elementwise_affine=False, eps=eps) + self.img_mlp = FeedForward(dim, inner_dim=mlp_hidden_dim, activation_fn="gelu-approximate") + + # text stream + self.txt_mod = JoyImageModulate(dim, factor=6) + self.txt_norm1 = FP32LayerNorm(dim, elementwise_affine=False, eps=eps) + self.txt_norm2 = FP32LayerNorm(dim, elementwise_affine=False, eps=eps) + self.txt_mlp = FeedForward(dim, inner_dim=mlp_hidden_dim, activation_fn="gelu-approximate") + + # ---- joint attention ---- + self.attn = JoyImageAttention(dim, num_attention_heads, attention_head_dim, eps=eps) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + image_rotary_emb: Tuple[torch.Tensor, torch.Tensor] | None = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # modulation + ( + img_mod1_shift, + img_mod1_scale, + img_mod1_gate, + img_mod2_shift, + img_mod2_scale, + img_mod2_gate, + ) = self.img_mod(temb) + ( + txt_mod1_shift, + txt_mod1_scale, + txt_mod1_gate, + txt_mod2_shift, + txt_mod2_scale, + txt_mod2_gate, + ) = self.txt_mod(temb) + + # --- attention --- + img_normed = self.img_norm1(hidden_states) + txt_normed = self.txt_norm1(encoder_hidden_states) + img_modulated = img_normed * (1 + img_mod1_scale.unsqueeze(1)) + img_mod1_shift.unsqueeze(1) + txt_modulated = txt_normed * (1 + txt_mod1_scale.unsqueeze(1)) + txt_mod1_shift.unsqueeze(1) + + img_attn, txt_attn = self.attn( + hidden_states=img_modulated, + encoder_hidden_states=txt_modulated, + image_rotary_emb=image_rotary_emb, + ) + + hidden_states = hidden_states + img_attn * img_mod1_gate.unsqueeze(1) + encoder_hidden_states = encoder_hidden_states + txt_attn * txt_mod1_gate.unsqueeze(1) + + # --- FFN --- + img_ffn_normed = self.img_norm2(hidden_states) + txt_ffn_normed = self.txt_norm2(encoder_hidden_states) + img_ffn_input = img_ffn_normed * (1 + img_mod2_scale.unsqueeze(1)) + img_mod2_shift.unsqueeze(1) + txt_ffn_input = txt_ffn_normed * (1 + txt_mod2_scale.unsqueeze(1)) + txt_mod2_shift.unsqueeze(1) + img_ffn_output = self.img_mlp(img_ffn_input) + txt_ffn_output = self.txt_mlp(txt_ffn_input) + hidden_states = hidden_states + img_ffn_output * img_mod2_gate.unsqueeze(1) + encoder_hidden_states = encoder_hidden_states + txt_ffn_output * txt_mod2_gate.unsqueeze(1) + + return hidden_states, encoder_hidden_states + + +class JoyImageTimeTextImageEmbedding(nn.Module): + def __init__( + self, + dim: int, + time_freq_dim: int, + time_proj_dim: int, + text_embed_dim: int, + ): + super().__init__() + + self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0) + self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim) + self.act_fn = nn.SiLU() + self.time_proj = nn.Linear(dim, time_proj_dim) + self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn="gelu_tanh") + + def forward( + self, + timestep: torch.Tensor, + encoder_hidden_states: torch.Tensor, + ): + timestep = self.timesteps_proj(timestep) + + time_embedder_dtype = 
next(iter(self.time_embedder.parameters())).dtype + if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8: + timestep = timestep.to(time_embedder_dtype) + temb = self.time_embedder(timestep).type_as(encoder_hidden_states) + timestep_proj = self.time_proj(self.act_fn(temb)) + + encoder_hidden_states = self.text_embedder(encoder_hidden_states) + + return temb, timestep_proj, encoder_hidden_states + + +# --------------------------------------------------------------------------- +# Main model +# --------------------------------------------------------------------------- + + +class JoyImageEditTransformer3DModel(ModelMixin, ConfigMixin, AttentionMixin): + """JoyImage Transformer model for image generation / editing. + + Dual-stream DiT architecture with WAN-style conditioning embeddings and custom rotary position embeddings. + """ + + _skip_layerwise_casting_patterns = ["img_in", "condition_embedder", "norm"] + _no_split_modules = ["JoyImageTransformerBlock"] + _supports_gradient_checkpointing = True + _keep_in_fp32_modules = [ + "time_embedder", + "norm1", + "norm2", + "norm_out", + ] + _repeated_blocks = ["JoyImageTransformerBlock"] + + @register_to_config + def __init__( + self, + patch_size: list = [1, 2, 2], + in_channels: int = 16, + out_channels: int | None = None, + hidden_size: int = 3072, + num_attention_heads: int = 24, + text_dim: int = 4096, + mlp_width_ratio: float = 4.0, + num_layers: int = 20, + rope_dim_list: list[int] = [16, 56, 56], + rope_type: str = "rope", + theta: int = 256, + ): + super().__init__() + + self.out_channels = out_channels or in_channels + self.patch_size = patch_size + self.hidden_size = hidden_size + self.num_attention_heads = num_attention_heads + self.rope_dim_list = rope_dim_list + self.rope_type = rope_type + self.theta = theta + + attention_head_dim = hidden_size // num_attention_heads + if hidden_size % num_attention_heads != 0: + raise ValueError( + f"hidden_size ({hidden_size}) must be divisible by num_attention_heads ({num_attention_heads})" + ) + + # image projection + self.img_in = nn.Conv3d(in_channels, hidden_size, kernel_size=patch_size, stride=patch_size) + + # condition embedder + self.condition_embedder = JoyImageTimeTextImageEmbedding( + dim=hidden_size, + time_freq_dim=256, + time_proj_dim=hidden_size * 6, + text_embed_dim=text_dim, + ) + + # double-stream blocks + self.double_blocks = nn.ModuleList( + [ + JoyImageTransformerBlock( + dim=hidden_size, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + mlp_width_ratio=mlp_width_ratio, + ) + for _ in range(num_layers) + ] + ) + + # output head + self.norm_out = FP32LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.proj_out = nn.Linear(hidden_size, self.out_channels * math.prod(patch_size)) + + self.gradient_checkpointing = False + + # ------------------------------------------------------------------ + # RoPE helper + # ------------------------------------------------------------------ + + def get_rotary_pos_embed( + self, + vis_rope_size: list[int], + txt_rope_size: int | None = None, + ): + target_ndim = 3 + if len(vis_rope_size) != target_ndim: + vis_rope_size = [1] * (target_ndim - len(vis_rope_size)) + list(vis_rope_size) + + head_dim = self.hidden_size // self.num_attention_heads + rope_dim_list = self.rope_dim_list + if rope_dim_list is None: + rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)] + if sum(rope_dim_list) != head_dim: + raise ValueError("sum(rope_dim_list) should equal 
head_dim") + + # Build a 3-D meshgrid [0, size) for each spatial axis + grid = torch.stack( + torch.meshgrid( + *[torch.linspace(0, s, s + 1, dtype=torch.float32)[:s] for s in vis_rope_size], + indexing="ij", + ), + dim=0, + ) + + # Per-axis 1-D rotary embeddings -> concat + vis_cos, vis_sin = [], [] + for i, dim in enumerate(rope_dim_list): + pos = grid[i].reshape(-1) + freqs = 1.0 / (self.theta ** (torch.arange(0, dim, 2, dtype=torch.float32)[: (dim // 2)] / dim)) + freqs = torch.outer(pos.float(), freqs) + vis_cos.append(freqs.cos().repeat_interleave(2, dim=1)) + vis_sin.append(freqs.sin().repeat_interleave(2, dim=1)) + vis_freqs = (torch.cat(vis_cos, dim=1), torch.cat(vis_sin, dim=1)) + + if txt_rope_size is None: + return vis_freqs, None + + # Text positions start right after the largest visual index + grid_txt = torch.arange(txt_rope_size) + grid.view(-1).max().item() + 1 + txt_cos, txt_sin = [], [] + for i, dim in enumerate(rope_dim_list): + freqs = 1.0 / (self.theta ** (torch.arange(0, dim, 2, dtype=torch.float32)[: (dim // 2)] / dim)) + freqs = torch.outer(grid_txt.float(), freqs) + txt_cos.append(freqs.cos().repeat_interleave(2, dim=1)) + txt_sin.append(freqs.sin().repeat_interleave(2, dim=1)) + txt_freqs = (torch.cat(txt_cos, dim=1), torch.cat(txt_sin, dim=1)) + + return vis_freqs, txt_freqs + + # ------------------------------------------------------------------ + # Unpatchify + # ------------------------------------------------------------------ + + def unpatchify(self, x: torch.Tensor, t: int, h: int, w: int) -> torch.Tensor: + c = self.out_channels + pt, ph, pw = self.patch_size + if t * h * w != x.shape[1]: + raise ValueError(f"Expected t*h*w ({t * h * w}) to equal x.shape[1] ({x.shape[1]})") + + x = x.reshape(x.shape[0], t, h, w, pt, ph, pw, c) + x = x.permute(0, 7, 1, 4, 2, 5, 3, 6) # nthwopqc -> nctohpwq + return x.reshape(x.shape[0], c, t * pt, h * ph, w * pw) + + # ------------------------------------------------------------------ + # Forward + # ------------------------------------------------------------------ + + def forward( + self, + hidden_states: torch.Tensor, + timestep: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + return_dict: bool = True, + ): + # handle multi-item input (b, n, c, t, h, w) + is_multi_item = hidden_states.ndim == 6 + num_items = 0 + if is_multi_item: + num_items = hidden_states.shape[1] + if num_items > 1: + if self.patch_size[0] != 1: + raise ValueError("For multi-item input, patch_size[0] must be 1") + hidden_states = torch.cat([hidden_states[:, -1:], hidden_states[:, :-1]], dim=1) + # rearrange: (b, n, c, t, h, w) -> (b, c, n*t, h, w) + b, n, c, t, h, w = hidden_states.shape + hidden_states = hidden_states.permute(0, 2, 1, 3, 4, 5).reshape(b, c, n * t, h, w) + + batch_size, _, ot, oh, ow = hidden_states.shape + tt = ot // self.patch_size[0] + th = oh // self.patch_size[1] + tw = ow // self.patch_size[2] + + # patchify + img = self.img_in(hidden_states).flatten(2).transpose(1, 2) + + # condition embeddings + _, vec, txt = self.condition_embedder(timestep, encoder_hidden_states) + if vec.shape[-1] > self.hidden_size: + vec = vec.unflatten(1, (6, -1)) + + txt_seq_len = txt.shape[1] + + # RoPE + vis_freqs, txt_freqs = self.get_rotary_pos_embed( + vis_rope_size=[tt, th, tw], + txt_rope_size=txt_seq_len if self.rope_type == "mrope" else None, + ) + + # main loop + for block in self.double_blocks: + if torch.is_grad_enabled() and self.gradient_checkpointing: + img, txt = self._gradient_checkpointing_func(block, img, txt, vec, 
(vis_freqs, txt_freqs)) + else: + img, txt = block( + hidden_states=img, + encoder_hidden_states=txt, + temb=vec, + image_rotary_emb=(vis_freqs, txt_freqs), + ) + + # final layer + img = self.proj_out(self.norm_out(img)) + img = self.unpatchify(img, tt, th, tw) + + # un-multi-item: (b, c, n*t, h, w) -> (b, n, c, t, h, w) + if is_multi_item: + c_out = img.shape[1] + img = img.reshape(batch_size, c_out, num_items, -1, oh, ow) + img = img.permute(0, 2, 1, 3, 4, 5) # (b, n, c, t, h, w) + if num_items > 1: + img = torch.cat([img[:, 1:], img[:, :1]], dim=1) + + if not return_dict: + return (img,) + return Transformer2DModelOutput(sample=img) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index c49ad3938cdc..f0fc7585bf31 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -333,6 +333,7 @@ "LTX2ImageToVideoPipeline", "LTX2LatentUpsamplePipeline", ] + _import_structure["joyimage"] = ["JoyImageEditPipeline", "JoyImageEditPipelineOutput"] _import_structure["lumina"] = ["LuminaPipeline", "LuminaText2ImgPipeline"] _import_structure["lumina2"] = ["Lumina2Pipeline", "Lumina2Text2ImgPipeline"] _import_structure["lucy"] = ["LucyEditPipeline"] @@ -727,6 +728,7 @@ ) from .hunyuan_video1_5 import HunyuanVideo15ImageToVideoPipeline, HunyuanVideo15Pipeline from .hunyuandit import HunyuanDiTPipeline + from .joyimage import JoyImageEditPipeline, JoyImageEditPipelineOutput from .kandinsky import ( KandinskyCombinedPipeline, KandinskyImg2ImgCombinedPipeline, diff --git a/src/diffusers/pipelines/joyimage/__init__.py b/src/diffusers/pipelines/joyimage/__init__.py new file mode 100644 index 000000000000..85b9246b22a6 --- /dev/null +++ b/src/diffusers/pipelines/joyimage/__init__.py @@ -0,0 +1,49 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa: F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_joyimage_edit"] = ["JoyImageEditPipeline"] + + _import_structure["pipeline_output"] = ["JoyImageEditPipelineOutput"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_joyimage_edit import JoyImageEditPipeline + from .pipeline_output import JoyImageEditPipelineOutput +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/joyimage/image_processor.py b/src/diffusers/pipelines/joyimage/image_processor.py new file mode 100644 index 000000000000..3aa7da1a0dcc --- /dev/null +++ b/src/diffusers/pipelines/joyimage/image_processor.py @@ -0,0 +1,149 @@ +# Copyright 2025 The JoyImage Team and The HuggingFace Team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Tuple + +from PIL import Image + +from ...configuration_utils import register_to_config +from ...image_processor import VaeImageProcessor + + +# fmt: off +BUCKETS = { + 1024: [ + (512, 1792), (512, 1856), (512, 1920), (512, 1984), (512, 2048), + (576, 1600), (576, 1664), (576, 1728), (576, 1792), + (640, 1472), (640, 1536), (640, 1600), + (704, 1344), (704, 1408), (704, 1472), + (768, 1216), (768, 1280), (768, 1344), + (832, 1152), (832, 1216), + (896, 1088), (896, 1152), + (960, 1024), (960, 1088), + (1024, 960), (1024, 1024), + (1088, 896), (1088, 960), + (1152, 832), (1152, 896), + (1216, 768), (1216, 832), + (1280, 768), + (1344, 704), (1344, 768), + (1408, 704), + (1472, 640), (1472, 704), + (1536, 640), + (1600, 576), (1600, 640), + (1664, 576), + (1728, 576), + (1792, 512), (1792, 576), + (1856, 512), + (1920, 512), + (1984, 512), + (2048, 512), + ], +} +# fmt: on + + +def find_best_bucket(height: int, width: int, basesize: int) -> Tuple[int, int]: + """Return the (h, w) bucket whose aspect ratio is closest to height/width.""" + target_ratio = height / width + return min( + BUCKETS[basesize], + key=lambda hw: abs(hw[0] / hw[1] - target_ratio), + ) + + +class JoyImageEditImageProcessor(VaeImageProcessor): + """ + Image processor for the JoyImage Edit pipeline. + + Handles bucket-based resolution selection and resize-center-crop preprocessing. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the image. + vae_scale_factor (`int`, *optional*, defaults to `8`): + VAE spatial scale factor. + basesize (`int`, *optional*, defaults to `1024`): + Base resolution for bucket generation. + resample (`str`, *optional*, defaults to `bilinear`): + Resampling filter for resizing. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image to [-1,1]. + do_binarize (`bool`, *optional*, defaults to `False`): + Whether to binarize the image to 0/1. + do_convert_rgb (`bool`, *optional*, defaults to `False`): + Whether to convert the images to RGB format. + do_convert_grayscale (`bool`, *optional*, defaults to `False`): + Whether to convert the images to grayscale format. 
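+
+    For example, a 1365x2048 (height x width) input has an aspect ratio of ~0.67 and is assigned the
+    (832, 1216) bucket, the closest ratio in the ``basesize=1024`` bucket table.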
+ """ + + @register_to_config + def __init__( + self, + do_resize: bool = True, + vae_scale_factor: int = 8, + basesize: int = 1024, + resample: str = "bilinear", + do_normalize: bool = True, + do_binarize: bool = False, + do_convert_rgb: bool = False, + do_convert_grayscale: bool = False, + ): + super().__init__() + + def get_default_height_width( + self, + image: Image.Image, + height: int | None = None, + width: int | None = None, + ) -> Tuple[int, int]: + if height is not None and width is not None: + src_w, src_h = width, height + elif image is None: + src_w, src_h = self.config.basesize, self.config.basesize + elif isinstance(image, list): + src_w, src_h = image[0].size + else: + src_w, src_h = image.size + + return find_best_bucket(src_h, src_w, self.config.basesize) + + def resize_center_crop( + self, + img, + target_size: Tuple[int, int], + ): + """ + Scale image to cover target_size, then center-crop. + + Args: + img: Input PIL image or list of PIL images. + target_size: (height, width) to crop to. + + Returns: + Resized and center-cropped PIL image(s), matching the input type. + """ + if isinstance(img, list): + return [self.resize_center_crop(i, target_size) for i in img] + + w, h = img.size + bh, bw = target_size + scale = max(bh / h, bw / w) + resize_h = math.ceil(h * scale) + resize_w = math.ceil(w * scale) + img = img.resize((resize_w, resize_h), Image.BILINEAR) + left = (resize_w - bw) // 2 + top = (resize_h - bh) // 2 + img = img.crop((left, top, left + bw, top + bh)) + return img diff --git a/src/diffusers/pipelines/joyimage/pipeline_joyimage_edit.py b/src/diffusers/pipelines/joyimage/pipeline_joyimage_edit.py new file mode 100644 index 000000000000..7cbb33a1fc7f --- /dev/null +++ b/src/diffusers/pipelines/joyimage/pipeline_joyimage_edit.py @@ -0,0 +1,890 @@ +import inspect +from typing import Callable, Dict, List, Optional, Union + +import numpy as np +import torch +from PIL import Image +from transformers import ( + Qwen2Tokenizer, + Qwen3VLForConditionalGeneration, + Qwen3VLProcessor, +) + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...models import AutoencoderKLWan, JoyImageEditTransformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .image_processor import JoyImageEditImageProcessor +from .pipeline_output import JoyImageEditPipelineOutput + + +EXAMPLE_DOC_STRING = """ +Examples: + ```python + >>> import torch + >>> from diffusers import JoyImageEditPipeline + >>> from diffusers.utils import load_image + + >>> model_id = "jdopensource/JoyAI-Image-Edit-Diffusers" + >>> pipe = JoyImageEditPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> image = load_image("https://huggingface.co/datasets/diffusers/docs-images/resolve/main/astronaut.jpg") + >>> output = pipe( + ... image=image, # pass an image for editing; omit for text-to-image generation + ... prompt="Add wings to the astronaut.", + ... num_inference_steps=40, + ... guidance_scale=4.0, + ... generator=torch.manual_seed(0), + ... 
) + >>> output.images[0].save("joyimage_edit.png") + ``` +""" + + +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Configure the scheduler and return its timestep sequence. + + Exactly one of ``timesteps``, ``sigmas``, or ``num_inference_steps`` should be provided to control the denoising + schedule. + + Args: + scheduler: The diffusion scheduler. + num_inference_steps: Number of denoising steps (used when neither + ``timesteps`` nor ``sigmas`` is given). + device: Target device for the timestep tensor. + timesteps: Custom discrete timesteps. + sigmas: Custom sigma values (alternative to ``timesteps``). + **kwargs: Additional keyword arguments forwarded to ``set_timesteps``. + + Returns: + Tuple of (timesteps tensor, num_inference_steps int). + + Raises: + ValueError: If both ``timesteps`` and ``sigmas`` are provided, or if the + scheduler does not support the requested schedule parameterisation. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed.") + + if timesteps is not None: + if "timesteps" not in set(inspect.signature(scheduler.set_timesteps).parameters.keys()): + raise ValueError(f"{scheduler.__class__} does not support custom timesteps.") + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + if "sigmas" not in set(inspect.signature(scheduler.set_timesteps).parameters.keys()): + raise ValueError(f"{scheduler.__class__} does not support custom sigmas.") + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + + return timesteps, num_inference_steps + + +class JoyImageEditPipeline(DiffusionPipeline): + """ + Diffusion pipeline for image editing using the JoyImage architecture. + + The pipeline encodes text and image conditioning via a Qwen3-VL text encoder, denoises latents with a 3-D + transformer, and decodes the result with a WAN VAE. + + Model offloading order: text_encoder -> transformer -> vae. + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKLWan, + text_encoder: Qwen3VLForConditionalGeneration, + tokenizer: Qwen2Tokenizer, + transformer: JoyImageEditTransformer3DModel, + processor: Qwen3VLProcessor, + text_token_max_length: int = 2048, + ): + """ + Initialise the pipeline and register all sub-modules. + + Args: + scheduler: Noise scheduler for the denoising process. + vae: Variational autoencoder used for encoding / decoding latents. + text_encoder: Qwen3-VL multimodal language model for prompt encoding. + tokenizer: Tokenizer paired with the text encoder. + transformer: 3-D transformer denoising network. + processor: Qwen3-VL processor for multi-image prompt preparation. + text_token_max_length: Maximum number of text tokens for the encoder. 
+ """ + super().__init__() + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + processor=processor, + ) + + self.text_token_max_length = text_token_max_length + + self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + self.vae_image_processor = JoyImageEditImageProcessor( + vae_scale_factor=self.vae_scale_factor_spatial, + ) + + # Prompt templates used when encoding text with / without image tokens. + self.prompt_template_encode = { + "image": ( + "<|im_start|>system\n \\nDescribe the image by detailing the color, shape, size, texture, " + "quantity, text, spatial relationships of the objects and background:<|im_end|>\n" + "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" + ), + "multiple_images": ( + "<|im_start|>system\n \\nDescribe the image by detailing the color, shape, size, texture, " + "quantity, text, spatial relationships of the objects and background:<|im_end|>\n" + "{}<|im_start|>assistant\n" + ), + } + # Number of system-prompt tokens to drop from the beginning of hidden states. + self.prompt_template_encode_start_idx = { + "image": 34, + "multiple_images": 34, + } + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + def _get_last_decoder_hidden_states(self, forward_fn, **kwargs): + """ + Run ``forward_fn(**kwargs)`` while capturing the **pre-norm** output of the last decoder layer via a forward + hook. + + This model was trained on transformers 4.57, where ``Qwen3VLForConditionalGeneration``'s + ``@check_model_inputs`` decorator monkey-patched each decoder layer to collect ``hidden_states``. Because + ``Qwen3VLCausalLMOutputWithPast`` has no ``last_hidden_state`` field, ``tie_last_hidden_states`` had no effect + and ``hidden_states[-1]`` was the **pre-norm** output of the last decoder layer. + + Starting from https://github.com/huggingface/transformers/pull/42609 the CausalLM forward explicitly returns + ``hidden_states=outputs.hidden_states`` from the inner model. Combined with the subsequent + ``@check_model_inputs`` → ``@capture_outputs`` migration (transformers 5.x), ``hidden_states`` is now captured + at the ``Qwen3VLTextModel`` level where ``tie_last_hidden_states=True`` replaces ``hidden_states[-1]`` with the + **post-norm** ``last_hidden_state``. The CausalLM simply passes this through, so ``hidden_states[-1]`` becomes + post-norm – a ~10× scale difference (std ≈ 2 vs ≈ 21) that breaks inference. + + This helper bypasses both mechanisms by hooking the last decoder layer directly, returning the raw pre-norm + output regardless of the transformers version. + """ + captured = {} + + def _hook(_module, _input, output): + captured["hidden_states"] = output[0] if isinstance(output, tuple) else output + + handle = self.text_encoder.model.language_model.layers[-1].register_forward_hook(_hook) + try: + forward_fn(**kwargs) + finally: + handle.remove() + return captured["hidden_states"] + + def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor) -> tuple[torch.Tensor, ...]: + """ + Extract valid (non-padded) hidden states for each sequence in the batch. 
+ + Args: + hidden_states: Shape (B, T, D). + mask: Binary attention mask of shape (B, T). + + Returns: + Tuple of tensors, one per batch element, each of shape (valid_T, D). + """ + bool_mask = mask.bool() + valid_lengths = bool_mask.sum(dim=1) + selected = hidden_states[bool_mask] + return torch.split(selected, valid_lengths.tolist(), dim=0) + + def _get_qwen_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + template_type: str = "image", + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Encode text prompts using the Qwen tokenizer (text-only path). + + Args: + prompt: A single prompt string or a list of prompt strings. + template_type: Key into ``prompt_template_encode`` / ``prompt_template_encode_start_idx``. + device: Target device. + dtype: Target floating-point dtype. + + Returns: + Tuple of (prompt_embeds, encoder_attention_mask) where both tensors have shape (B, max_seq_len, D) and (B, + max_seq_len) respectively, zero-padded to the same length. + """ + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + template = self.prompt_template_encode[template_type] + drop_idx = self.prompt_template_encode_start_idx[template_type] + + txt = [template.format(e) for e in prompt] + txt_tokens = self.tokenizer( + txt, + max_length=self.text_token_max_length + drop_idx, + padding=True, + truncation=True, + return_tensors="pt", + ).to(device) + + hidden_states = self._get_last_decoder_hidden_states( + self.text_encoder, + input_ids=txt_tokens.input_ids, + attention_mask=txt_tokens.attention_mask, + ) + + # Drop system-prompt prefix tokens and re-pack into a padded batch. + split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask) + split_hidden_states = [e[drop_idx:] for e in split_hidden_states] + attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] + + max_seq_len = min( + self.text_token_max_length, + max(u.size(0) for u in split_hidden_states), + max(u.size(0) for u in attn_mask_list), + ) + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] + ) + encoder_attention_mask = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] + ) + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + return prompt_embeds, encoder_attention_mask + + def encode_prompt_multiple_images( + self, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + images: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_embeds_mask: Optional[torch.Tensor] = None, + template_type: Optional[str] = "multiple_images", + max_sequence_length: Optional[int] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Encode prompts that contain inline image tokens via the Qwen processor. + + ``\\n`` placeholders in each prompt string are replaced by the Qwen vision special tokens before being + fed to the multimodal encoder. + + Args: + prompt: Prompt string(s), optionally containing ``\\n`` tokens. + device: Target device. + num_images_per_prompt: Number of outputs to generate per prompt. + images: Pixel tensors corresponding to the inline image tokens. + prompt_embeds: Pre-computed prompt embeddings. 
+ prompt_embeds_mask: Attention mask for pre-computed embeddings. + template_type: Must be ``"multiple_images"``. + max_sequence_length: If set, truncate the output to this length + (keeping the last ``max_sequence_length`` tokens). + + Returns: + Tuple of (prompt_embeds, prompt_embeds_mask). + """ + if template_type != "multiple_images": + raise ValueError(f"Expected template_type 'multiple_images', but got '{template_type}'") + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] + + if prompt_embeds is None: + template = self.prompt_template_encode[template_type] + drop_idx = self.prompt_template_encode_start_idx[template_type] + + prompt = [f"\n{p}" for p in prompt] + prompt = [f"<|im_start|>user\n{p}<|im_end|>\n" for p in prompt] + + prompt = [p.replace("\n", "<|vision_start|><|image_pad|><|vision_end|>") for p in prompt] + prompt = [template.format(p) for p in prompt] + + if images is not None: + if not isinstance(images, list): + images = [images] * len(prompt) + elif len(images) < len(prompt) and len(prompt) % len(images) == 0: + images = images * (len(prompt) // len(images)) + + inputs = self.processor( + text=prompt, + images=images, + padding=True, + return_tensors="pt", + ).to(device) + + last_hidden_states = self._get_last_decoder_hidden_states(self.text_encoder, **inputs) + + prompt_embeds = last_hidden_states[:, drop_idx:] + prompt_embeds_mask = inputs["attention_mask"][:, drop_idx:] + + if max_sequence_length is not None and prompt_embeds.shape[1] > max_sequence_length: + prompt_embeds = prompt_embeds[:, -max_sequence_length:, :] + prompt_embeds_mask = prompt_embeds_mask[:, -max_sequence_length:] + + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) + prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) + + return prompt_embeds, prompt_embeds_mask + + def encode_prompt( + self, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_embeds_mask: Optional[torch.Tensor] = None, + max_sequence_length: int = 1024, + template_type: str = "image", + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Encode a text prompt into embeddings (text-only path). + + Pre-computed ``prompt_embeds`` bypass encoding entirely. + + Args: + prompt: Prompt string or list of prompt strings. + device: Target device. + num_images_per_prompt: Number of outputs to generate per prompt. + prompt_embeds: Pre-computed prompt embeddings. + prompt_embeds_mask: Attention mask for pre-computed embeddings. + max_sequence_length: Maximum output sequence length. + template_type: Prompt template key (``"image"`` or ``"multiple_images"``). + + Returns: + Tuple of (prompt_embeds, prompt_embeds_mask). 
+ """ + device = device or self._execution_device + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, template_type, device) + + prompt_embeds = prompt_embeds[:, :max_sequence_length] + prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length] + + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) + prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) + + return prompt_embeds, prompt_embeds_mask + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_embeds_mask=None, + negative_prompt_embeds_mask=None, + callback_on_step_end_tensor_inputs=None, + ): + """ + Validate pipeline inputs before the forward pass. + + Raises: + ValueError: On any invalid combination of arguments. + """ + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError("`callback_on_step_end_tensor_inputs` has invalid keys.") + + if prompt is not None and prompt_embeds is not None: + raise ValueError("Cannot forward both `prompt` and `prompt_embeds`.") + elif prompt is None and prompt_embeds is None: + raise ValueError("Provide either `prompt` or `prompt_embeds`.") + elif prompt is not None and not isinstance(prompt, (str, list)): + raise ValueError("`prompt` has to be of type `str` or `list`.") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError("Cannot forward both `negative_prompt` and `negative_prompt_embeds`.") + + if prompt_embeds is not None and prompt_embeds_mask is None: + raise ValueError("If `prompt_embeds` are provided, `prompt_embeds_mask` is required.") + if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None: + raise ValueError("If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` is required.") + + def normalize_latents(self, latent: torch.Tensor) -> torch.Tensor: + """ + Normalise latents using per-channel statistics from the VAE config. + + Uses (latent - mean) / std when the VAE exposes ``latents_mean`` and ``latents_std``; otherwise falls back to + scaling by ``scaling_factor``. + + Args: + latent: Raw latent tensor from ``vae.encode``. + + Returns: + Normalised latent tensor. + """ + if hasattr(self.vae.config, "latents_mean") and hasattr(self.vae.config, "latents_std"): + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, -1, 1, 1, 1) + .to(device=latent.device, dtype=latent.dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std) + .view(1, -1, 1, 1, 1) + .to(device=latent.device, dtype=latent.dtype) + ) + latent = (latent - latents_mean) / latents_std + else: + latent = latent * self.vae.config.scaling_factor + return latent + + def denormalize_latents(self, latent: torch.Tensor) -> torch.Tensor: + """ + Invert :meth:`normalize_latents` to recover the original latent scale. + + Args: + latent: Normalised latent tensor. + + Returns: + Latent tensor in the scale expected by ``vae.decode``. 
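+
+        Example (a minimal sketch; ``pipe`` is the pipeline instance and the VAE config is assumed to
+        expose ``latents_mean``/``latents_std``)::
+
+            z = pipe.normalize_latents(raw_latent)    # (raw_latent - mean) / std
+            restored = pipe.denormalize_latents(z)    # z * std + mean, recovers raw_latent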
+ """ + if hasattr(self.vae.config, "latents_mean") and hasattr(self.vae.config, "latents_std"): + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, -1, 1, 1, 1) + .to(device=latent.device, dtype=latent.dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std) + .view(1, -1, 1, 1, 1) + .to(device=latent.device, dtype=latent.dtype) + ) + latent = latent * latents_std + latents_mean + else: + latent = latent / self.vae.config.scaling_factor + return latent + + def prepare_latents( + self, + batch_size: int, + num_items: int, + num_channels_latents: int, + height: int, + width: int, + video_length: int, + dtype: torch.dtype, + device: torch.device, + generator: Optional[Union[torch.Generator, List[torch.Generator]]], + latents: Optional[torch.Tensor] = None, + reference_images: Optional[List[Image.Image]] = None, + enable_denormalization: bool = True, + ) -> torch.Tensor: + """ + Prepare the initial noisy latent tensor for the denoising loop. + + When ``reference_images`` is provided the first (num_items - 1) slots are filled with VAE-encoded reference + image latents; the last slot is random noise. When ``latents`` is provided it is moved to ``device`` without + modification. Otherwise pure random noise is returned. + + Args: + batch_size: Number of samples in the batch. + num_items: Number of image slots (reference + target). + num_channels_latents: Latent channel dimension from the transformer config. + height: Spatial height in pixels. + width: Spatial width in pixels. + video_length: Number of frames (1 for image inference). + dtype: Floating-point dtype for the latent tensor. + device: Target device. + generator: RNG generator(s) for reproducible sampling. + latents: Optional pre-allocated latent tensor. + reference_images: Optional list of PIL images to encode as conditioning. + enable_denormalization: Whether to normalise encoded reference latents. + + Returns: + Latent tensor of shape (B, num_items, C, T, H', W'). + + Raises: + ValueError: If ``generator`` is a list whose length differs from ``batch_size``. + """ + shape = ( + batch_size, + num_items, + num_channels_latents, + (video_length - 1) // self.vae_scale_factor_temporal + 1, + int(height) // self.vae_scale_factor_spatial, + int(width) // self.vae_scale_factor_spatial, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError("Generator list length must match batch size.") + + if latents is None: + if reference_images is not None: + if batch_size > len(reference_images) and batch_size % len(reference_images) == 0: + reference_images = reference_images * (batch_size // len(reference_images)) + elif batch_size > len(reference_images): + raise ValueError( + f"Cannot duplicate `image` of batch size {len(reference_images)} to {batch_size} text prompts." + ) + # Encode reference images and concatenate with a noise slot. 
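+                # Shape flow for B reference images of size (H, W):
+                #   stacked uint8 pixels                     -> (B, H, W, 3)
+                #   scale to [-1, 1], NCHW, add a frame axis -> (B, 3, 1, H, W)
+                #   vae.encode(...).sample()                 -> (B, C_lat, 1, H // s, W // s),
+                #   where s = self.vae_scale_factor_spatial.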
+ ref_img = [torch.from_numpy(np.array(x.convert("RGB"))) for x in reference_images] + ref_img = torch.stack(ref_img).to(device=device, dtype=dtype) + ref_img = ref_img / 127.5 - 1.0 + ref_img = ref_img.permute(0, 3, 1, 2).unsqueeze(2) + ref_vae = self.vae.encode(ref_img).latent_dist.sample() + if enable_denormalization: + ref_vae = self.normalize_latents(ref_vae) + ref_vae = ref_vae.view(shape[0], num_items - 1, *ref_vae.shape[1:]) + noise = randn_tensor( + (shape[0], 1, *shape[2:]), + generator=generator, + device=device, + dtype=dtype, + ) + latents = torch.cat([ref_vae, noise], dim=1) + else: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + return latents + + # ------------------------------------------------------------------ + # Pipeline properties + # ------------------------------------------------------------------ + + @property + def guidance_scale(self) -> float: + """Classifier-free guidance scale used in the current forward pass.""" + return self._guidance_scale + + @property + def do_classifier_free_guidance(self) -> bool: + """True when guidance_scale > 1, enabling classifier-free guidance.""" + return self._guidance_scale > 1 + + @property + def num_timesteps(self) -> int: + """Total number of denoising timesteps in the current forward pass.""" + return self._num_timesteps + + @property + def interrupt(self) -> bool: + """When True, the denoising loop is interrupted at the next step.""" + return self._interrupt + + # ------------------------------------------------------------------ + # Forward pass + # ------------------------------------------------------------------ + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: PipelineImageInput | None = None, + prompt: str | list[str] = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 40, + timesteps: List[int] = None, + sigmas: List[float] = None, + guidance_scale: float = 4.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_embeds_mask: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds_mask: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback_on_step_end: Optional[ + Union[ + Callable[[int, int, Dict], None], + PipelineCallback, + MultiPipelineCallbacks, + ] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 4096, + enable_denormalization: bool = True, + ): + r""" + Generate an edited image conditioned on a reference image and a text prompt. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide generation. + height (`int`): + Height of the generated output in pixels. + width (`int`): + Width of the generated output in pixels. + image (`PipelineImageInput`, *optional*): + Reference image used for conditioning. When provided the pipeline operates in image-editing mode with + ``num_items=2``. + num_inference_steps (`int`, *optional*, defaults to 40): + Number of denoising steps. More steps generally improve quality at the cost of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps for the denoising process. 
When provided, ``num_inference_steps`` is inferred from the + list length. + sigmas (`List[float]`, *optional*): + Custom sigmas for the denoising process. Mutually exclusive with ``timesteps``. + guidance_scale (`float`, *optional*, defaults to 4.0): + Classifier-free guidance scale. + negative_prompt (`str` or `List[str]`, *optional*): + Negative prompt(s) used to suppress undesired content. + num_images_per_prompt (`int`, *optional*, defaults to 1): + Number of generated samples per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + RNG generator(s) for deterministic sampling. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents. Sampled from a Gaussian distribution when not provided. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-computed prompt embeddings. When provided ``prompt`` can be omitted. + prompt_embeds_mask (`torch.Tensor`, *optional*): + Attention mask for ``prompt_embeds``. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-computed negative prompt embeddings. + negative_prompt_embeds_mask (`torch.Tensor`, *optional*): + Attention mask for ``negative_prompt_embeds``. + output_type (`str`, *optional*, defaults to ``"pil"``): + Output format. Pass ``"latent"`` to return raw latents. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a :class:`JoyImageEditPipelineOutput` or a plain tensor. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + Callback invoked at the end of each denoising step with signature ``(self, step: int, timestep: int, + callback_kwargs: Dict)``. + callback_on_step_end_tensor_inputs (`List[str]`, *optional*, defaults to ``["latents"]``): + Tensor keys included in ``callback_kwargs`` for ``callback_on_step_end``. + max_sequence_length (`int`, *optional*, defaults to 4096): + Maximum sequence length for prompt encoding. + enable_denormalization (`bool`, *optional*, defaults to `True`): + Denormalise latents before VAE decoding. + + Examples: + + Returns: + [`~pipelines.joyimage.JoyImageEditPipelineOutput`] or `torch.Tensor`: + If ``return_dict`` is ``True``, returns a pipeline output object containing the generated image(s). + Otherwise returns the image tensor directly. + """ + # Resize the input image to the nearest bucket resolution. + # Or resize the specified height and width to the nearest bucket resolution. + height, width = self.vae_image_processor.get_default_height_width(image, height, width) + processed_image = None + if image is not None: + processed_image = self.vae_image_processor.resize_center_crop(image, (height, width)) + + self.check_inputs( + prompt, + height, + width, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + negative_prompt_embeds_mask=negative_prompt_embeds_mask, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._interrupt = False + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # num_items: 1 for unconditional generation, 2 for reference-image editing. + num_items = 1 if image is None else 2 + + # Encode the conditioning prompt. 
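+        # With a reference image, the multimodal (text + vision) encoder path is used so the prompt
+        # tokens can attend to the image; otherwise the text-only path is used. When classifier-free
+        # guidance is enabled, the same processed image also conditions the negative prompt below.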
+ if processed_image is not None: + prompt_embeds, prompt_embeds_mask = self.encode_prompt_multiple_images( + prompt=prompt, + images=processed_image, + prompt_embeds=prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + else: + prompt_embeds, prompt_embeds_mask = self.encode_prompt( + prompt=prompt, + prompt_embeds=prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + + if self.do_classifier_free_guidance: + # Build default negative prompts when none are provided. + if negative_prompt is None and negative_prompt_embeds is None: + negative_prompt = [""] * batch_size + + if processed_image is not None: + negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt_multiple_images( + prompt=negative_prompt, + images=processed_image, + prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=negative_prompt_embeds_mask, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + else: + negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt( + prompt=negative_prompt, + prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=negative_prompt_embeds_mask, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + timesteps, + sigmas, + ) + + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_items, + num_channels_latents, + height, + width, + 1, # video_length = 1 for image inference + prompt_embeds.dtype, + device, + generator, + latents, + reference_images=( + (processed_image if isinstance(processed_image, list) else [processed_image]) + if processed_image is not None + else None + ), + enable_denormalization=enable_denormalization, + ) + + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + # Cache reference latents to restore them at each denoising step. + if num_items > 1: + ref_latents = latents[:, : (num_items - 1)].clone() + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # Restore reference latents so they are never overwritten by the scheduler. + if num_items > 1: + latents[:, : (num_items - 1)] = ref_latents.clone() + + latent_model_input = latents + t_expand = t.repeat(latent_model_input.shape[0]) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=t_expand, + encoder_hidden_states=prompt_embeds, + return_dict=False, + )[0] + + if self.do_classifier_free_guidance: + noise_pred_uncond = self.transformer( + hidden_states=latent_model_input, + timestep=t_expand, + encoder_hidden_states=negative_prompt_embeds, + return_dict=False, + )[0] + + comb_pred = noise_pred_uncond + self.guidance_scale * (noise_pred - noise_pred_uncond) + # Rescale to match the conditional prediction norm (guidance rescaling). 
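+                    # noise_pred = comb_pred * ||noise_pred_cond|| / max(||comb_pred||, 1e-6),
+                    # with norms taken per position over the latent channel dimension (dim=2).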
+ cond_norm = torch.norm(noise_pred, dim=2, keepdim=True) + noise_norm = torch.norm(comb_pred, dim=2, keepdim=True) + noise_pred = comb_pred * (cond_norm / noise_norm.clamp_min(1e-6)) + + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + if progress_bar is not None: + progress_bar.update() + + if output_type != "latent": + latents = latents.flatten(0, 1) + if enable_denormalization: + latents = self.denormalize_latents(latents) + + image = self.vae.decode(latents, return_dict=False)[0] + image = image.unflatten(0, (batch_size * num_images_per_prompt, -1)) + else: + image = latents + + # Extract the target slot (last item) from each batch element. + # (B, num_items, C, T, H, W) -> permute -> (B, num_items, T, C, H, W) -> [:, -1] -> (B, T, C, H, W) + image = image.float().permute(0, 1, 3, 2, 4, 5)[:, -1].squeeze(1) + + image = self.image_processor.postprocess(image, output_type=output_type) + + self.maybe_free_model_hooks() + + if not return_dict: + return image + + return JoyImageEditPipelineOutput(images=image) diff --git a/src/diffusers/pipelines/joyimage/pipeline_output.py b/src/diffusers/pipelines/joyimage/pipeline_output.py new file mode 100644 index 000000000000..175dce3540d7 --- /dev/null +++ b/src/diffusers/pipelines/joyimage/pipeline_output.py @@ -0,0 +1,16 @@ +from dataclasses import dataclass +from typing import List, Union + +import numpy as np +import PIL.Image + +from ...utils import BaseOutput + + +@dataclass +class JoyImageEditPipelineOutput(BaseOutput): + """ + Output class for JoyImageEdit generation pipelines. 
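+
+    Args:
+        images (`List[PIL.Image.Image]` or `np.ndarray`):
+            The generated images, as a list of PIL images or a NumPy array, depending on the
+            requested ``output_type``.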
+ """ + + images: Union[List[PIL.Image.Image], np.ndarray] diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 60222c2b6fca..9bfb73c1999e 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -1365,6 +1365,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class JoyImageEditTransformer3DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class Kandinsky3UNet(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 6511345e9511..570a3f4dd7c3 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -1967,6 +1967,36 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class JoyImageEditPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class JoyImageEditPipelineOutput(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class Kandinsky3Img2ImgPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/models/transformers/test_models_transformer_joyimage.py b/tests/models/transformers/test_models_transformer_joyimage.py new file mode 100644 index 000000000000..c464a44c29b5 --- /dev/null +++ b/tests/models/transformers/test_models_transformer_joyimage.py @@ -0,0 +1,109 @@ +# Copyright 2025 The HuggingFace Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pytest +import torch + +from diffusers import JoyImageEditTransformer3DModel +from diffusers.utils.torch_utils import randn_tensor + +from ...testing_utils import enable_full_determinism, torch_device +from ..testing_utils import ( + AttentionTesterMixin, + BaseModelTesterConfig, + MemoryTesterMixin, + ModelTesterMixin, + TorchCompileTesterMixin, + TrainingTesterMixin, +) + + +enable_full_determinism() + + +class JoyImageEditTransformerTesterConfig(BaseModelTesterConfig): + @property + def model_class(self): + return JoyImageEditTransformer3DModel + + @property + def output_shape(self) -> tuple[int, ...]: + return (16, 1, 4, 4) + + @property + def input_shape(self) -> tuple[int, ...]: + return (16, 1, 4, 4) + + @property + def main_input_name(self) -> str: + return "hidden_states" + + @property + def uses_custom_attn_processor(self) -> bool: + return True + + @property + def model_split_percents(self) -> list: + return [0.7, 0.6, 0.6] + + @property + def generator(self): + return torch.Generator("cpu").manual_seed(0) + + def get_init_dict(self) -> dict[str, int | list[int]]: + return { + "patch_size": [1, 2, 2], + "in_channels": 16, + "hidden_size": 32, + "num_attention_heads": 2, + "text_dim": 16, + "num_layers": 2, + "rope_dim_list": [4, 6, 6], + "theta": 256, + } + + def get_dummy_inputs(self) -> dict[str, torch.Tensor]: + batch_size = 1 + hidden_states = randn_tensor((batch_size, 16, 1, 4, 4), generator=self.generator, device=torch_device) + encoder_hidden_states = randn_tensor((batch_size, 12, 16), generator=self.generator, device=torch_device) + timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size) + return { + "hidden_states": hidden_states, + "encoder_hidden_states": encoder_hidden_states, + "timestep": timestep, + } + + +class TestJoyImageEditTransformer(JoyImageEditTransformerTesterConfig, ModelTesterMixin): + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"]) + def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype): + pytest.skip("Tolerance requirements too high for meaningful test") + + +class TestJoyImageEditTransformerMemory(JoyImageEditTransformerTesterConfig, MemoryTesterMixin): + pass + + +class TestJoyImageEditTransformerTraining(JoyImageEditTransformerTesterConfig, TrainingTesterMixin): + def test_gradient_checkpointing_is_applied(self): + expected_set = {"JoyImageEditTransformer3DModel"} + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) + + +class TestJoyImageEditTransformerAttention(JoyImageEditTransformerTesterConfig, AttentionTesterMixin): + pass + + +class TestJoyImageEditTransformerCompile(JoyImageEditTransformerTesterConfig, TorchCompileTesterMixin): + pass diff --git a/tests/pipelines/joyimage/__init__.py b/tests/pipelines/joyimage/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/joyimage/test_joyimage_edit.py b/tests/pipelines/joyimage/test_joyimage_edit.py new file mode 100644 index 000000000000..0f201ace7d28 --- /dev/null +++ b/tests/pipelines/joyimage/test_joyimage_edit.py @@ -0,0 +1,239 @@ +# Copyright 2025 The HuggingFace Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from unittest.mock import patch + +import numpy as np +import pytest +import torch +from PIL import Image +from transformers import Qwen3VLForConditionalGeneration, Qwen3VLProcessor + +from diffusers import ( + AutoencoderKLWan, + FlowMatchEulerDiscreteScheduler, + JoyImageEditPipeline, + JoyImageEditTransformer3DModel, +) +from diffusers.hooks import apply_group_offloading + +from ...testing_utils import enable_full_determinism, torch_device +from ..pipeline_params import TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import PipelineTesterMixin + + +enable_full_determinism() + + +class JoyImageEditPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = JoyImageEditPipeline + params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} + batch_params = frozenset(["prompt", "image"]) + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "return_dict", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", + ] + ) + supports_dduf = False + test_xformers_attention = False + test_layerwise_casting = True + test_group_offloading = True + + def setUp(self): + super().setUp() + self._bucket_patcher = patch( + "diffusers.pipelines.joyimage.image_processor.find_best_bucket", + return_value=(32, 32), + ) + self._bucket_patcher.start() + + def tearDown(self): + self._bucket_patcher.stop() + super().tearDown() + + def get_dummy_components(self): + tiny_ckpt_id = "huangfeice/tiny-random-Qwen3VLForConditionalGeneration" + + torch.manual_seed(0) + transformer = JoyImageEditTransformer3DModel( + patch_size=[1, 2, 2], + in_channels=16, + hidden_size=32, + num_attention_heads=2, + text_dim=16, + num_layers=1, + rope_dim_list=[4, 6, 6], + theta=256, + ) + + torch.manual_seed(0) + vae = AutoencoderKLWan( + base_dim=3, + z_dim=16, + dim_mult=[1, 1, 1, 1], + num_res_blocks=1, + temperal_downsample=[False, True, True], + ) + + scheduler = FlowMatchEulerDiscreteScheduler() + + processor = Qwen3VLProcessor.from_pretrained(tiny_ckpt_id) + processor.image_processor.min_pixels = 4 * 28 * 28 + processor.image_processor.max_pixels = 4 * 28 * 28 + + text_encoder = Qwen3VLForConditionalGeneration.from_pretrained(tiny_ckpt_id) + text_encoder.resize_token_embeddings(len(processor.tokenizer)) + + components = { + "transformer": transformer, + "vae": vae, + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": processor.tokenizer, + "processor": processor, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + inputs = { + "prompt": "a cat sitting on a bench", + "image": Image.new("RGB", (32, 32)), + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 1.0, + "height": 32, + "width": 32, + "max_sequence_length": 16, + "output_type": "pt", + } + return inputs + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + 
pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(device)
+        image = pipe(**inputs).images
+        generated_image = image[0]
+
+        self.assertEqual(generated_image.shape, (3, 32, 32))
+
+    def test_inference_batch_single_identical(self):
+        self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-1)
+
+    @unittest.skip("num_images_per_prompt not applicable: each prompt is bound to a reference image")
+    def test_num_images_per_prompt(self):
+        pass
+
+    @unittest.skip("Test not supported")
+    def test_attention_slicing_forward_pass(self):
+        pass
+
+    @pytest.mark.xfail(condition=True, reason="Preconfigured embeddings need to be revisited.", strict=True)
+    def test_encode_prompt_works_in_isolation(self, extra_required_param_value_dict=None, atol=1e-4, rtol=1e-4):
+        super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict, atol, rtol)
+
+    def test_group_offloading_inference(self):
+        # Qwen3VLForConditionalGeneration (the text encoder) is incompatible with leaf_level group
+        # offloading. Its Qwen3VLVisionModel.fast_pos_embed_interpolate reads
+        # `self.pos_embed.weight.device` to create intermediate tensors before the Embedding's
+        # pre_forward hook fires, so the intermediate tensors land on CPU while hidden_states
+        # (produced by the Conv3d patch_embed) land on CUDA, causing a device mismatch.
+        #
+        # block_level works correctly: since Qwen3VLForConditionalGeneration has no ModuleList as a
+        # direct child, the entire model forms one unmatched group that onloads atomically before any
+        # submodule code runs, so pos_embed.weight.device is CUDA by the time it is read.
+        #
+        # For leaf_level we therefore move the text encoder to the target device directly (the same
+        # pattern the base test already uses for the VAE) and only apply leaf_level offloading to
+        # the diffusers-native transformer.
+        if not self.test_group_offloading:
+            return
+
+        def create_pipe():
+            torch.manual_seed(0)
+            components = self.get_dummy_components()
+            pipe = self.pipeline_class(**components)
+            pipe.set_progress_bar_config(disable=None)
+            return pipe
+
+        def run_forward(pipe):
+            torch.manual_seed(0)
+            inputs = self.get_dummy_inputs(torch_device)
+            return pipe(**inputs)[0]
+
+        pipe = create_pipe().to(torch_device)
+        output_without_group_offloading = run_forward(pipe)
+
+        # block_level: the full text encoder becomes one group (no direct ModuleList children), so
+        # the atomic onload/offload is safe.
+ pipe = create_pipe() + for component_name in ["transformer", "text_encoder"]: + component = getattr(pipe, component_name, None) + if component is None: + continue + if hasattr(component, "enable_group_offload"): + component.enable_group_offload( + torch.device(torch_device), offload_type="block_level", num_blocks_per_group=1 + ) + else: + apply_group_offloading( + component, + onload_device=torch.device(torch_device), + offload_type="block_level", + num_blocks_per_group=1, + ) + pipe.vae.to(torch_device) + output_with_block_level = run_forward(pipe) + + pipe = create_pipe() + pipe.transformer.enable_group_offload(torch.device(torch_device), offload_type="leaf_level") + pipe.text_encoder.to(torch_device) + pipe.vae.to(torch_device) + output_with_leaf_level = run_forward(pipe) + + if torch.is_tensor(output_without_group_offloading): + output_without_group_offloading = output_without_group_offloading.detach().cpu().numpy() + output_with_block_level = output_with_block_level.detach().cpu().numpy() + output_with_leaf_level = output_with_leaf_level.detach().cpu().numpy() + + self.assertTrue(np.allclose(output_without_group_offloading, output_with_block_level, atol=1e-4)) + self.assertTrue(np.allclose(output_without_group_offloading, output_with_leaf_level, atol=1e-4)) + + @unittest.skip("Qwen3VLForConditionalGeneration does not support leaf-level group offloading") + def test_pipeline_level_group_offloading_inference(self): + pass + + @unittest.skip("Qwen3VLForConditionalGeneration does not support sequential CPU offloading") + def test_sequential_cpu_offload_forward_pass(self): + pass + + @unittest.skip("Qwen3VLForConditionalGeneration does not support sequential CPU offloading") + def test_sequential_offload_forward_pass_twice(self): + pass