[feat] JoyAI-JoyImage-Edit support #13444

Open

Moran232 wants to merge 43 commits into huggingface:main from Moran232:joyimage_edit

Conversation


@Moran232 Moran232 commented Apr 10, 2026

Description

We are the JoyAI Team, and this is the Diffusers implementation for the JoyAI-Image-Edit model.

GitHub Repository: https://github.com/jd-opensource/JoyAI-Image
Hugging Face Model: https://huggingface.co/jdopensource/JoyAI-Image-Edit-Diffusers
Original open-source weights: https://huggingface.co/jdopensource/JoyAI-Image-Edit
Fixes #13430

Model Overview

JoyAI-Image is a unified multimodal foundation model for image understanding, text-to-image generation, and instruction-guided image editing. It combines an 8B Multimodal Large Language Model (MLLM) with a 16B Multimodal Diffusion Transformer (MMDiT).

Key Features

  • Advanced Text Rendering Showcase: JoyAI-Image is optimized for challenging text-heavy scenarios, including multi-panel comics, dense multi-line text, multilingual typography, long-form layouts, real-world scene text, and handwritten styles.
  • Multi-view Generation and Spatial Editing Showcase: JoyAI-Image showcases a spatially grounded generation and editing pipeline that supports multi-view generation, geometry-aware transformations, camera control, object rotation, and precise location-specific object editing. Across these settings, it preserves scene content, structure, and visual consistency while following viewpoint-sensitive instructions more accurately.
  • Spatial Editing for Spatial Reasoning Showcase: JoyAI-Image performs high-fidelity spatial editing, serving as a powerful catalyst for enhancing spatial reasoning. Compared with Qwen-Image-Edit and Nano Banana Pro, JoyAI-Image-Edit synthesizes the most diagnostic viewpoints by faithfully executing camera motions. These high-fidelity novel views effectively disambiguate complex spatial relations, providing clearer visual evidence for downstream reasoning.

Image edit examples

spatial-editing-showcase

@github-actions github-actions Bot added models pipelines size/L PR with diff > 200 LOC labels Apr 10, 2026
@yiyixuxu yiyixuxu left a comment


thanks for the PR! I left some initial feedback

Comment thread src/diffusers/models/transformers/transformer_joyimage.py Outdated
return x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))


class JoyImageEditTransformer3DModel(JoyImageTransformer3DModel):
Collaborator

ohh what's going on here? is this some legacy code? can we remove?

Author

We first developed JoyImage, and then trained JoyImage-Edit based on it. This Transformer 3D model belongs to JoyImage, and JoyImage-Edit is inherited from JoyImage. We will also open-source JoyImage in the future.

They essentially share similar Transformer 3D models. I understand that each pipeline requires a specific Transformer model, which is why we implemented inheritance in this way.
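
For readers skimming the diff, a minimal sketch of the inheritance relationship described above (class names from the PR; bases and bodies are illustrative, not the PR's actual code):

from diffusers.configuration_utils import ConfigMixin
from diffusers.models.modeling_utils import ModelMixin


class JoyImageTransformer3DModel(ModelMixin, ConfigMixin):
    """Base MMDiT used by JoyImage text-to-image (sketch)."""


class JoyImageEditTransformer3DModel(JoyImageTransformer3DModel):
    """Edit variant: reuses the base blocks, so only edit-specific pieces differ."""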

Comment on lines +371 to +391
img_qkv = self.img_attn_qkv(img_modulated)
img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
img_q = self.img_attn_q_norm(img_q).to(img_v)
img_k = self.img_attn_k_norm(img_k).to(img_v)
if vis_freqs_cis is not None:
    img_q, img_k = apply_rotary_emb(img_q, img_k, vis_freqs_cis, head_first=False)

txt_modulated = modulate(self.txt_norm1(txt), shift=txt_mod1_shift, scale=txt_mod1_scale)
txt_qkv = self.txt_attn_qkv(txt_modulated)
txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
txt_q = self.txt_attn_q_norm(txt_q).to(txt_v)
txt_k = self.txt_attn_k_norm(txt_k).to(txt_v)
if txt_freqs_cis is not None:
    txt_q, txt_k = apply_rotary_emb(txt_q, txt_k, txt_freqs_cis, head_first=False)

q = torch.cat((img_q, txt_q), dim=1)
k = torch.cat((img_k, txt_k), dim=1)
v = torch.cat((img_v, txt_v), dim=1)

attn = attention(q, k, v, attn_kwargs=attn_kwargs).flatten(2, 3)
img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1] :]
Collaborator

Suggested change
img_qkv = self.img_attn_qkv(img_modulated)
img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
img_q = self.img_attn_q_norm(img_q).to(img_v)
img_k = self.img_attn_k_norm(img_k).to(img_v)
if vis_freqs_cis is not None:
    img_q, img_k = apply_rotary_emb(img_q, img_k, vis_freqs_cis, head_first=False)
txt_modulated = modulate(self.txt_norm1(txt), shift=txt_mod1_shift, scale=txt_mod1_scale)
txt_qkv = self.txt_attn_qkv(txt_modulated)
txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
txt_q = self.txt_attn_q_norm(txt_q).to(txt_v)
txt_k = self.txt_attn_k_norm(txt_k).to(txt_v)
if txt_freqs_cis is not None:
    txt_q, txt_k = apply_rotary_emb(txt_q, txt_k, txt_freqs_cis, head_first=False)
q = torch.cat((img_q, txt_q), dim=1)
k = torch.cat((img_k, txt_k), dim=1)
v = torch.cat((img_v, txt_v), dim=1)
attn = attention(q, k, v, attn_kwargs=attn_kwargs).flatten(2, 3)
img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1] :]
attn_output, text_attn_output = self.attn(...)

can we refactor the attention implementation to follow diffusers style?
basically you need to move all the layers used in attention calculation here into a JoyImageAttention (similar to FluxAttention https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_flux.py#L275)

also create a JoyImageAttnProcessor (see FluxAttnProcessor as an example, I think it is the same: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_flux.py#L75)
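
To make the suggested split concrete, here is a hedged sketch of the FluxAttention-style pattern: all weights live on a JoyImageAttention module, and the computation lives in a stateless JoyImageAttnProcessor. RoPE application is omitted for brevity, and names/shapes are illustrative rather than the PR's final code:

import torch
import torch.nn as nn
import torch.nn.functional as F


class JoyImageAttnProcessor:
    """Stateless attention computation; holds no weights (sketch)."""

    def __call__(self, attn, img, txt):
        def qkv(proj, x, q_norm, k_norm):
            # (B, L, 3*dim) -> three (B, heads, L, head_dim) tensors
            q, k, v = proj(x).chunk(3, dim=-1)
            shape = (x.shape[0], -1, attn.heads, attn.head_dim)
            q, k, v = (t.reshape(shape).transpose(1, 2) for t in (q, k, v))
            return q_norm(q), k_norm(k), v

        iq, ik, iv = qkv(attn.img_attn_qkv, img, attn.img_attn_q_norm, attn.img_attn_k_norm)
        tq, tk, tv = qkv(attn.txt_attn_qkv, txt, attn.txt_attn_q_norm, attn.txt_attn_k_norm)
        # joint attention over concatenated image + text tokens
        q, k, v = (torch.cat(p, dim=2) for p in ((iq, tq), (ik, tk), (iv, tv)))
        out = F.scaled_dot_product_attention(q, k, v).transpose(1, 2).flatten(2)
        return out[:, : img.shape[1]], out[:, img.shape[1] :]


class JoyImageAttention(nn.Module):
    """Holds the attention weights; delegates the math to the processor."""

    def __init__(self, dim: int, heads: int):
        super().__init__()
        self.heads, self.head_dim = heads, dim // heads
        self.img_attn_qkv = nn.Linear(dim, dim * 3, bias=True)
        self.txt_attn_qkv = nn.Linear(dim, dim * 3, bias=True)
        self.img_attn_q_norm = nn.RMSNorm(self.head_dim)
        self.img_attn_k_norm = nn.RMSNorm(self.head_dim)
        self.txt_attn_q_norm = nn.RMSNorm(self.head_dim)
        self.txt_attn_k_norm = nn.RMSNorm(self.head_dim)
        self.processor = JoyImageAttnProcessor()

    def forward(self, img, txt):
        return self.processor(self, img, txt)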

Author

Thanks for the reminder. I'll clean up this messy code.

Author

Fixed in d397b68

Comment thread src/diffusers/models/transformers/transformer_joyimage.py Outdated
Comment on lines +242 to +250
class ModulateX(nn.Module):
    def __init__(self, hidden_size: int, factor: int, dtype=None, device=None):
        super().__init__()
        self.factor = factor

    def forward(self, x: torch.Tensor):
        if len(x.shape) != 3:
            x = x.unsqueeze(1)
        return [o.squeeze(1) for o in x.chunk(self.factor, dim=1)]
Collaborator

Suggested change
class ModulateX(nn.Module):
    def __init__(self, hidden_size: int, factor: int, dtype=None, device=None):
        super().__init__()
        self.factor = factor
    def forward(self, x: torch.Tensor):
        if len(x.shape) != 3:
            x = x.unsqueeze(1)
        return [o.squeeze(1) for o in x.chunk(self.factor, dim=1)]

Comment on lines +214 to +225
class ModulateDiT(nn.Module):
    def __init__(self, hidden_size: int, factor: int, act_layer=nn.SiLU, dtype=None, device=None):
        factory_kwargs = {"dtype": dtype, "device": device}
        super().__init__()
        self.factor = factor
        self.act = act_layer()
        self.linear = nn.Linear(hidden_size, factor * hidden_size, bias=True, **factory_kwargs)
        nn.init.zeros_(self.linear.weight)
        nn.init.zeros_(self.linear.bias)

    def forward(self, x: torch.Tensor):
        return self.linear(self.act(x)).chunk(self.factor, dim=-1)
Collaborator

Suggested change
class ModulateDiT(nn.Module):
    def __init__(self, hidden_size: int, factor: int, act_layer=nn.SiLU, dtype=None, device=None):
        factory_kwargs = {"dtype": dtype, "device": device}
        super().__init__()
        self.factor = factor
        self.act = act_layer()
        self.linear = nn.Linear(hidden_size, factor * hidden_size, bias=True, **factory_kwargs)
        nn.init.zeros_(self.linear.weight)
        nn.init.zeros_(self.linear.bias)
    def forward(self, x: torch.Tensor):
        return self.linear(self.act(x)).chunk(self.factor, dim=-1)

Is ModulateWan the one used in the model? If so, let's remove ModulateDiT and ModulateX

Author

Fixed in f557113

head_dim = hidden_size // heads_num
mlp_hidden_dim = int(hidden_size * mlp_width_ratio)

self.img_mod = load_modulation(self.dit_modulation_type, hidden_size, 6, **factory_kwargs)
Collaborator

Suggested change
self.img_mod = load_modulation(self.dit_modulation_type, hidden_size, 6, **factory_kwargs)
self.img_mod = JoyImageModulate(...)

let's remove the load_modulation function and use the layer directly; better to rename it to JoyImageModulate too
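
For reference, a minimal sketch of what that renamed layer could look like (same SiLU + zero-initialized Linear logic as the ModulateDiT shown above; the name JoyImageModulate follows the reviewer's suggestion, so this is illustrative rather than the PR's final code):

import torch
import torch.nn as nn


class JoyImageModulate(nn.Module):
    """AdaLN-style modulation: project the conditioning vector into
    `factor` shift/scale/gate chunks (sketch)."""

    def __init__(self, hidden_size: int, factor: int):
        super().__init__()
        self.factor = factor
        self.act = nn.SiLU()
        self.linear = nn.Linear(hidden_size, factor * hidden_size, bias=True)
        # zero init so modulation starts as identity
        nn.init.zeros_(self.linear.weight)
        nn.init.zeros_(self.linear.bias)

    def forward(self, x: torch.Tensor):
        return self.linear(self.act(x)).chunk(self.factor, dim=-1)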

Author

Ok, I will refactor modulation and use ModulateWan

tacos8me added a commit to tacos8me/taco-desktop-backend that referenced this pull request Apr 11, 2026
New `model="joyai-edit"` on /v1/image-edit and /v2/image-edit, routed to a separate FastAPI sidecar that runs JoyImageEditPipeline from the Moran232/diffusers fork (pending PR huggingface/diffusers#13444)
Comment thread src/diffusers/models/transformers/transformer_joyimage.py Outdated
Comment on lines +454 to +459
self.args = SimpleNamespace(
    enable_activation_checkpointing=enable_activation_checkpointing,
    is_repa=is_repa,
    repa_layer=repa_layer,
)

Collaborator

Suggested change
self.args = SimpleNamespace(
    enable_activation_checkpointing=enable_activation_checkpointing,
    is_repa=is_repa,
    repa_layer=repa_layer,
)

I think we can use self.config here (e.g. self.config.is_repa, self.config.repa_layer, etc.) instead of needing to define a separate namespace.
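
For context: diffusers models whose __init__ is decorated with @register_to_config expose the constructor arguments on self.config automatically, so no extra namespace is needed. A minimal sketch (class name and arguments are illustrative):

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin


class ExampleTransformer(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(self, is_repa: bool = False, repa_layer: int = 8):
        super().__init__()
        # no SimpleNamespace needed: the decorator stores the
        # __init__ arguments on self.config

    def forward(self, x):
        if self.config.is_repa:
            print(f"REPA projection at layer {self.config.repa_layer}")
        return x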


@Moran232 Moran232 Apr 14, 2026


I deleted the repa-related code; see f557113

Collaborator

Was the repa logic removed because it is not used in inference?

Comment thread src/diffusers/pipelines/joyimage/pipeline_joyimage_edit.py Outdated
Comment on lines +900 to +901
timesteps: List[int] = None,
sigmas: List[float] = None,
Collaborator

Suggested change
timesteps: List[int] = None,
sigmas: List[float] = None,
timesteps: list[int] | None = None,
sigmas: list[float] | None = None,

nit: could we switch to Python 3.9+ style built-in type hints here and elsewhere?

Comment thread src/diffusers/pipelines/joyimage/pipeline_joyimage_edit.py Outdated

@dg845 dg845 left a comment


Thanks for the PR! Left an initial design review :).

@github-actions github-actions Bot added size/L PR with diff > 200 LOC and removed size/L PR with diff > 200 LOC labels Apr 14, 2026
@Moran232

@yiyixuxu @dg845
Thank you very much for your valuable feedback. I've made some modifications. See my latest commits.

Specifically, I refactored the attention module. However, since the weight key names in the Diffusers model are already fixed, I didn't change the actual keys in the attention part. Additionally, I will consider refactoring the image pre-processing logic; since the logic is quite complex, I directly copied it over from the training code.

If you have any further suggestions, please feel free to share. Thank you so much!

# ---- joint attention (fused QKV, directly on the block) ----
# image attention layers
self.img_attn_qkv = nn.Linear(dim, inner_dim * 3, bias=True)
self.img_attn_q_norm = nn.RMSNorm(attention_head_dim, eps=eps)
Collaborator

If I remember correctly, the attention sublayer used to use the custom RMSNorm module, which upcasted to FP32 during the RMS computation. Here we're using torch.nn.RMSNorm, which doesn't. Is this intentional?
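
For reference, the difference being pointed at: the diffusers RMSNorm (in models/normalization.py) upcasts to FP32 for the variance computation and casts back, while torch.nn.RMSNorm computes in the input dtype. A simplified sketch of the upcasting variant:

import torch
import torch.nn as nn


class UpcastRMSNorm(nn.Module):
    """RMSNorm that computes the statistics in FP32, then casts back
    (simplified from the diffusers RMSNorm behavior)."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        input_dtype = x.dtype
        x = x.to(torch.float32)  # upcast before computing the variance
        variance = x.pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.eps)
        return self.weight * x.to(input_dtype)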

Comment thread src/diffusers/models/transformers/transformer_joyimage.py Outdated
Comment thread src/diffusers/pipelines/joyimage/pipeline_joyimage_edit.py Outdated
Comment thread setup.py Outdated
@github-actions github-actions Bot added size/L PR with diff > 200 LOC and removed size/L PR with diff > 200 LOC labels Apr 30, 2026

feice-huang commented Apr 30, 2026

@dg845 @yiyixuxu
Hello! I have completed a new round of revisions. There are two major updates:

  • Uploaded a smaller checkpoint to huangfeice/tiny-random-Qwen3VLForConditionalGeneration and updated the test config. 9d9ef52

  • Refactored the pipeline; prompts can now be automatically wrapped. Additionally, we enabled t2i (text-to-image) in the pipeline: the model now performs image generation when no image is provided. (JoyImageEdit is primarily developed for editing tasks, yet it also possesses inherent image generation capabilities.) 87b5383

Looking forward to the merge! Wishing you a Happy International Workers’ Day!

Here are some scripts:

Inference Scripts
import torch

from diffusers import JoyImageEditPipeline
from diffusers.utils import load_image

pipeline = JoyImageEditPipeline.from_pretrained(
    "jdopensource/JoyAI-Image-Edit-Diffusers", torch_dtype=torch.bfloat16
)
pipeline.to("cuda")

img_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
image = load_image(img_path)

# edit
image = pipeline(
    image=image,
    prompt="Add wings to the astronaut.",
    generator=torch.Generator("cuda").manual_seed(0),
    guidance_scale=4.0,
).images[0]

image.save("edit.png")

# t2i
image = pipeline(
    prompt="A toy astronaut with wings.",
    generator=torch.Generator("cuda").manual_seed(0),
    guidance_scale=4.0,
).images[0]

image.save("t2i.png")

# batch edit
output = pipeline(
    image=[image, image],
    prompt=["Add wings to the astronaut.", "Add halo to the astronaut."],
    generator=torch.Generator("cuda").manual_seed(0),
    guidance_scale=4.0,
)

output.images[0].save("result1.png")
output.images[1].save("result2.png")


dg845 commented Apr 30, 2026

@bot /style

@github-actions

Style fix is beginning .... View the workflow run here.

@github-actions github-actions Bot added size/L PR with diff > 200 LOC and removed size/L PR with diff > 200 LOC labels Apr 30, 2026

dg845 commented May 1, 2026

@bot /style


github-actions Bot commented May 1, 2026

Style fix is beginning .... View the workflow run here.


dg845 commented May 1, 2026

When I run the JoyImageEditPipeline tests locally with

pytest tests/pipelines/joyimage/test_joyimage_edit.py

the following tests fail:

  • test_group_offloading_inference
  • test_pipeline_level_group_offloading_inference
  • test_sequential_cpu_offload_forward_pass
  • test_sequential_offload_forward_pass_twice

After debugging, my understanding is that the root cause of these test failures is that Qwen3VLForConditionalGeneration doesn't support leaf-level offloading (specifically, Qwen3VLVisionModel.fast_pos_embed_interpolate breaks the leaf-level offloading assumption by caching the device before the pre_forward offloading hook fires, then using it for device placement afterwards, so the pos_embeds activation it produces ends up on the cached offload device rather than the onload device.)

Block-level offloading does work for Qwen3VLForConditionalGeneration because the Qwen3VLModel submodule is treated as a single atomic group, which avoids the "device read and placement between groups" device-mismatch issue.

For now, I think we can either skip the tests or override them so that the Qwen3-VL text_encoder is excluded when testing leaf-level group offloading.

Claude suggestion for overriding test_group_offloading_inference
    def test_group_offloading_inference(self):
        # Qwen3VLForConditionalGeneration (the text encoder) is incompatible with leaf_level group
        # offloading. Its Qwen3VLVisionModel.fast_pos_embed_interpolate reads
        # `self.pos_embed.weight.device` to create intermediate tensors before the Embedding's
        # pre_forward hook fires, so the intermediate tensors land on CPU while hidden_states
        # (produced by the Conv3d patch_embed) land on CUDA, causing a device mismatch.
        #
        # block_level works correctly: since Qwen3VLForConditionalGeneration has no ModuleList as a
        # direct child, the entire model forms one unmatched group that onloads atomically before any
        # submodule code runs, so pos_embed.weight.device is CUDA by the time it is read.
        #
        # For leaf_level we therefore move the text encoder to the target device directly (the same
        # pattern the base test already uses for the VAE) and only apply leaf_level offloading to
        # the diffusers-native transformer.
        if not self.test_group_offloading:
            return

        def create_pipe():
            torch.manual_seed(0)
            components = self.get_dummy_components()
            pipe = self.pipeline_class(**components)
            pipe.set_progress_bar_config(disable=None)
            return pipe

        def run_forward(pipe):
            torch.manual_seed(0)
            inputs = self.get_dummy_inputs(torch_device)
            return pipe(**inputs)[0]

        pipe = create_pipe().to(torch_device)
        output_without_group_offloading = run_forward(pipe)

        # block_level: the full text encoder becomes one group (no direct ModuleList children), so
        # the atomic onload/offload is safe.
        pipe = create_pipe()
        for component_name in ["transformer", "text_encoder"]:
            component = getattr(pipe, component_name, None)
            if component is None:
                continue
            if hasattr(component, "enable_group_offload"):
                component.enable_group_offload(
                    torch.device(torch_device), offload_type="block_level", num_blocks_per_group=1
                )
            else:
                apply_group_offloading(
                    component,
                    onload_device=torch.device(torch_device),
                    offload_type="block_level",
                    num_blocks_per_group=1,
                )
        pipe.vae.to(torch_device)
        output_with_block_level = run_forward(pipe)

        # leaf_level: skip the text encoder (transformers model with device-dependent tensor
        # creation) and move it to the target device directly.
        pipe = create_pipe()
        pipe.transformer.enable_group_offload(
            torch.device(torch_device), offload_type="leaf_level"
        )
        pipe.text_encoder.to(torch_device)
        pipe.vae.to(torch_device)
        output_with_leaf_level = run_forward(pipe)

        if torch.is_tensor(output_without_group_offloading):
            output_without_group_offloading = output_without_group_offloading.detach().cpu().numpy()
            output_with_block_level = output_with_block_level.detach().cpu().numpy()
            output_with_leaf_level = output_with_leaf_level.detach().cpu().numpy()

        self.assertTrue(np.allclose(output_without_group_offloading, output_with_block_level, atol=1e-4))
        self.assertTrue(np.allclose(output_without_group_offloading, output_with_leaf_level, atol=1e-4))

In the longer term (a separate PR), we can consider adding support for specifying modules that only support block-level offloading in e.g. DiffusionPipeline.enable_group_offload/apply_group_offloading, for example with something like a per_module_offload_type dict argument. CC @sayakpaul
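
As a rough illustration of that idea (entirely hypothetical; per_module_offload_type is a proposal, not an existing diffusers argument):

import torch

# `pipe` is a loaded DiffusionPipeline; the last kwarg below is hypothetical.
pipe.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",
    # proposed: force block-level for modules that break leaf-level assumptions
    per_module_offload_type={"text_encoder": "block_level"},
)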


dg845 commented May 1, 2026

Also, can you run make style and make quality to fix the code style so that the CI is green? The style bot doesn't seem to be working so I don't think I can do it from my side.

Comment thread src/diffusers/pipelines/joyimage/pipeline_joyimage_edit.py

@yiyixuxu yiyixuxu left a comment


thanks!!
I left some final comments! Let's merge this soon :)

Comment on lines +356 to +359
time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
    timestep = timestep.to(time_embedder_dtype)
temb = self.time_embedder(timestep).type_as(encoder_hidden_states)
Collaborator

Suggested change
time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
    timestep = timestep.to(time_embedder_dtype)
temb = self.time_embedder(timestep).type_as(encoder_hidden_states)
timestep = timestep.to(encoder_hidden_states.dtype)
temb = self.time_embedder(timestep)

Would this work? I know this function is copied from WAN, but this is a pattern we generally want to avoid (inferring dtype from weights), and I cannot recall why we did it this way with WAN (could just be an oversight)

let me know if we can simplify here a bit


I tried the simplification but it causes a RuntimeError: mat1 and mat2 must have the same dtype, but got BFloat16 and Float at runtime.

This is because condition_embedder is listed in _skip_layerwise_casting_patterns, so its weights stay in Float32 during layerwise casting, while encoder_hidden_states is BFloat16. Casting timestep to encoder_hidden_states.dtype (bf16) then feeding it into the Float32 time_embedder linear layers triggers the mismatch.

The current pattern reads the actual weight dtype, aligns the input to it, computes in that precision, then casts the output back. This is necessary, and I think this is also why WAN does it the same way.
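
To spell the pattern out in isolation, a minimal standalone sketch of the flow under layerwise casting (shapes and module choices are illustrative, not the PR's code):

import torch
import torch.nn as nn

# simulate layerwise casting: the time embedder stays in FP32 while
# activations (and the rest of the network) run in bfloat16
time_embedder = nn.Linear(256, 1024).float()
encoder_hidden_states = torch.randn(1, 77, 1024, dtype=torch.bfloat16)
timestep = torch.randn(1, 256, dtype=torch.bfloat16)

# align the input to the actual weight dtype, compute in that precision,
# then cast the output back to the activation dtype
embedder_dtype = next(time_embedder.parameters()).dtype  # torch.float32
temb = time_embedder(timestep.to(embedder_dtype)).type_as(encoder_hidden_states)
assert temb.dtype == torch.bfloat16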

Error message
python solve.py
Skipping import of cpp extensions due to incompatible torch version 2.8.0+cu128 for torchao version 0.16.0. Please see https://github.com/pytorch/ao/issues/2919 for more info
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:01<00:00,  4.84it/s]
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:01<00:00,  3.06it/s]
Loading pipeline components...: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:04<00:00,  1.34it/s]
  0%|                                                                                                                                                 | 0/40 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/pfs/huangfeice/solve/solve.py", line 18, in <module>
    image = pipeline(
  File "/root/miniforge/envs/comfyenv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
  File "/pfs/huangfeice/diffusers/src/diffusers/pipelines/joyimage/pipeline_joyimage_edit.py", line 845, in __call__
    noise_pred = self.transformer(
  File "/root/miniforge/envs/comfyenv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/miniforge/envs/comfyenv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
  File "/pfs/huangfeice/diffusers/src/diffusers/models/transformers/transformer_joyimage.py", line 550, in forward
    _, vec, txt = self.condition_embedder(timestep, encoder_hidden_states)
  File "/root/miniforge/envs/comfyenv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/miniforge/envs/comfyenv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
  File "/pfs/huangfeice/diffusers/src/diffusers/models/transformers/transformer_joyimage.py", line 357, in forward
    temb = self.time_embedder(timestep)
  File "/root/miniforge/envs/comfyenv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/miniforge/envs/comfyenv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
  File "/pfs/huangfeice/diffusers/src/diffusers/models/embeddings.py", line 1297, in forward
    sample = self.linear_1(sample)
  File "/root/miniforge/envs/comfyenv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/miniforge/envs/comfyenv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/miniforge/envs/comfyenv/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 125, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 must have the same dtype, but got BFloat16 and Float

Comment on lines +386 to +392
if images is not None:
    return self.encode_prompt_multiple_images(
        prompt=prompt,
        images=images,
        device=device,
        max_sequence_length=max_sequence_length,
    )
Collaborator

Suggested change
if images is not None:
    return self.encode_prompt_multiple_images(
        prompt=prompt,
        images=images,
        device=device,
        max_sequence_length=max_sequence_length,
    )

can we move encode_prompt_multiple_images out from encode_prompt?

so in pipeline.__call__

if processed_image is None:
    ...  = self.encode_prompt()
else:
   ...  = self.encode_prompt_multiple_images()

additionally, let's make sure encode_prompt_multiple_images also accepts pre-computed prompt_embeds/prompt_embeds_mask and expands based on num_images_per_prompt, similar to what's done in encode_prompt


Refactored in 63baf43

callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 4096,
enable_denormalization: bool = True,
**kwargs,
Collaborator

ohh what's this for? can we remove this?


ok, removed.

@github-actions github-actions Bot added size/L PR with diff > 200 LOC and removed size/L PR with diff > 200 LOC labels May 1, 2026

yiyixuxu commented May 1, 2026

@claude can you do a final review too?


github-actions Bot commented May 1, 2026

Claude finished @yiyixuxu's task in 3m 24s


Final Review of JoyAI-JoyImage-Edit PR

  • Read repository guidelines (CLAUDE.md, models.md, pipelines.md)
  • Review transformer model (transformer_joyimage.py)
  • Review image processor (image_processor.py)
  • Review pipeline (pipeline_joyimage_edit.py)
  • Review pipeline output and init/export files
  • Review tests
  • Check outstanding reviewer comments
  • Post final review

The PR has improved significantly since the last round. Many of the earlier issues (einops removal, attention refactor, gradient checkpointing, assert→ValueError, etc.) have been addressed. Below are the remaining issues I found, organized by severity. I've also noted which items align with outstanding reviewer comments from @yiyixuxu and @dg845 that still need resolution.


Must Fix

1. _keep_in_fp32_modules references non-existent modules (transformer_joyimage.py:381-387)

_keep_in_fp32_modules = [
    "time_embedder",
    "scale_shift_table",  # ← does not exist in this model
    "norm1",
    "norm2",
    "norm3",              # ← does not exist in this model
]

scale_shift_table and norm3 are WAN concepts copied over but not present in JoyImageEditTransformer3DModel. These entries are harmless (they just match nothing) but are misleading. Update to match the actual module names in this model.

2. __init__.py export ordering (src/diffusers/__init__.py:584-586)

"JoyImageEditPipeline",
"Kandinsky3Img2ImgPipeline",
"JoyImageEditPipelineOutput",  # ← out of alphabetical order

JoyImageEditPipelineOutput should come immediately after JoyImageEditPipeline, before Kandinsky3Img2ImgPipeline. This will likely be caught by make quality.

3. enable_tiling still in docstring (pipeline_joyimage_edit.py:698-699)

            enable_tiling (`bool`, *optional*, defaults to `False`):
                Enable tiled VAE decoding to reduce peak memory usage.

This parameter does not exist in __call__. Remove from docstring.

4. **kwargs in __call__ silently swallows arguments (pipeline_joyimage_edit.py:649)

yiyixuxu flagged this. **kwargs is accepted but never used — it silently discards any mistyped keyword arguments. Remove it.


Should Fix (Outstanding reviewer comments)

5. encode_prompt_multiple_images should be called from __call__ directly — yiyixuxu's comment

Currently encode_prompt internally routes to encode_prompt_multiple_images when images is not None. yiyixuxu asked to move the routing to __call__:

# In __call__:
if processed_image is None:
    ... = self.encode_prompt(...)
else:
    ... = self.encode_prompt_multiple_images(...)

Additionally, encode_prompt_multiple_images should accept pre-computed prompt_embeds/prompt_embeds_mask and handle num_images_per_prompt expansion, similar to encode_prompt.

6. Timestep dtype inference via weight introspection (transformer_joyimage.py:356-358) — yiyixuxu's comment

time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
    timestep = timestep.to(time_embedder_dtype)

Per models.md gotcha #4: "don't cast activations by reading a weight's dtype — the stored weight dtype isn't the compute dtype under gguf / quantized loading." yiyixuxu suggested simplifying to:

timestep = timestep.to(encoder_hidden_states.dtype)
temb = self.time_embedder(timestep)

7. nn.RMSNorm vs diffusers RMSNorm (transformer_joyimage.py:207-208, 212-213)

The attention sublayer QK norms use nn.RMSNorm which does not upcast to FP32 during computation. The diffusers RMSNorm (from normalization.py) does upcast. dg845 flagged this earlier. If the original model's training used FP32 upcast for QK norms, this could cause numerical drift. Please verify whether the original training code upcasts, and use the appropriate norm class.


Nice to Fix

8. Two forward passes instead of batched CFG (pipeline_joyimage_edit.py:823-836)

The denoising loop makes two separate transformer forward passes (conditional + unconditional) instead of batching them together. This is ~2x slower than the standard diffusers pattern:

latent_model_input = torch.cat([latents] * 2)
prompt_concat = torch.cat([negative_prompt_embeds, prompt_embeds])
noise_pred = self.transformer(latent_model_input, ..., prompt_concat, ...)
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)

I understand this may be intentional due to the multi-item latent structure, but it's worth considering for performance.

9. Always-on guidance rescaling (pipeline_joyimage_edit.py:839-842)

cond_norm = torch.norm(noise_pred, dim=2, keepdim=True)
noise_norm = torch.norm(comb_pred, dim=2, keepdim=True)
noise_pred = comb_pred * (cond_norm / noise_norm.clamp_min(1e-6))

This norm-based guidance rescaling is always applied during CFG. Most diffusers pipelines make this optional via a guidance_rescale parameter. If this is an intentional design choice for JoyImage, a brief comment explaining why would help readers.
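
If the authors want it switchable, here is a hedged sketch of gating the rescale behind a guidance_rescale-style flag (the parameter name is borrowed from other diffusers pipelines and is not part of this PR):

import torch


def rescale_noise_pred(comb_pred: torch.Tensor, noise_pred: torch.Tensor,
                       guidance_rescale: float = 1.0) -> torch.Tensor:
    """Blend between raw CFG output and the norm-rescaled prediction.
    guidance_rescale=0.0 disables rescaling; 1.0 matches current behavior."""
    if guidance_rescale == 0.0:
        return comb_pred
    cond_norm = torch.norm(noise_pred, dim=2, keepdim=True)
    noise_norm = torch.norm(comb_pred, dim=2, keepdim=True)
    rescaled = comb_pred * (cond_norm / noise_norm.clamp_min(1e-6))
    return guidance_rescale * rescaled + (1.0 - guidance_rescale) * comb_pred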

10. Mixed typing style (pipeline_joyimage_edit.py:2, 620-649)

The pipeline mixes old-style typing (Dict, List, Optional, Union) and modern style (str | list[str], int | None). Similarly, the transformer uses from typing import Tuple instead of tuple. Should be consistent — prefer Python 3.9+ builtins per dg845's earlier suggestion.

11. Tuple in transformer (transformer_joyimage.py:17)

from typing import Tuple — replace all Tuple[...] with tuple[...] for Python 3.9+ consistency.

12. image_processor.py bucket validation (image_processor.py:57-63)

find_best_bucket will KeyError if basesize != 1024. Add a check:

if basesize not in BUCKETS:
    raise ValueError(f"Unsupported basesize {basesize}. Supported: {list(BUCKETS.keys())}")

13. @register_to_config on VaeImageProcessor subclass (image_processor.py:91)

VaeImageProcessor.__init__ already uses register_to_config. Double-decorating in the subclass may cause issues with config serialization. Verify this works correctly with from_config.


Style / CI

14. make style and make quality needed — dg845 asked for this. The style bot doesn't seem to be working. The authors should run make style && make quality && make fix-copies locally and push the result.


Tests

15. test_from_save_pretrained_dtype_inference unconditionally skipped (test_models_transformer_joyimage.py:90-91)

pytest.skip("Tolerance requirements too high for meaningful test")

This unconditionally skips a core test. If there are tolerance issues, investigate the root cause or use @pytest.mark.xfail with a reason so it's tracked. pytest.skip hides the problem entirely.
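
For example, a sketch of the xfail alternative (the reason text is illustrative):

import pytest


@pytest.mark.xfail(reason="output tolerance currently exceeds the shared test threshold")
def test_from_save_pretrained_dtype_inference():
    ...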

16. Offloading test failures — dg845 identified that test_group_offloading_inference, test_pipeline_level_group_offloading_inference, test_sequential_cpu_offload_forward_pass, and test_sequential_offload_forward_pass_twice fail due to Qwen3VL incompatibility with leaf-level offloading. These need to be either overridden (as dg845 suggested) or skipped with clear documentation.


Summary

The architecture is clean — JoyImageAttention + JoyImageAttnProcessor with dispatch_attention_fn follows the recommended pattern well. Gradient checkpointing, _repeated_blocks, and _no_split_modules are all properly set. The _get_last_decoder_hidden_states hook for transformers 5.x compat is a pragmatic solution.

The main outstanding items are yiyixuxu's latest review comments (#5, #6, #4), the stale _keep_in_fp32_modules entries (#1), and running make style && make quality (#14). Once those are addressed, this should be ready to merge.



yiyixuxu commented May 1, 2026

ohh Claude did a good job here, can you look into fixing #1 and #3 listed in the Claude review?


yiyixuxu commented May 1, 2026

@claude are you able to do a dead code analysis since it's a new model here?
(also, can you let me know why you did not include this in the review already? it's listed in review-rules.md)


github-actions Bot commented May 1, 2026

Claude finished @yiyixuxu's task in 2m 9s


Dead Code Analysis for JoyImage-Edit Model

  • Read review rules to understand the dead code analysis requirements
  • Read the pipeline __call__ and trace all calls into the transformer
  • Read the transformer model __init__ and forward
  • Read the published model config (jdopensource/JoyAI-Image-Edit-Diffusers)
  • Cross-reference: identify unused parameters, methods, and unreachable code paths
  • Explain why previous reviews missed this analysis
  • Post final dead code analysis report

Why this wasn't in the previous reviews

The review rules in .ai/review-rules.md explicitly require a dead code analysis section for PRs that add new models ("Trace the call path... Check the default model config... Flag unused parameters and methods"). My two earlier reviews focused on correctness, API design, and convention issues but failed to follow the dead-code-analysis checklist. That was an oversight on my part — the rule was there, I just didn't execute it. Sorry about that.


Call path trace

Pipeline __call__ (lines 619–880) calls into the transformer at line 823:

noise_pred = self.transformer(
    hidden_states=latent_model_input,   # (B, num_items, C, T, H, W) — 6D when editing
    timestep=t_expand,
    encoder_hidden_states=prompt_embeds,
    return_dict=False,
)[0]

Only four arguments are ever passed: hidden_states, timestep, encoder_hidden_states, and return_dict. The transformer's forward signature is:

def forward(self, hidden_states, timestep, encoder_hidden_states=None, return_dict=True):

All four parameters are used. Good — no unused forward parameters.

Published config (jdopensource/JoyAI-Image-Edit-Diffusers/transformer/config.json):

{
  "hidden_size": 4096,
  "num_attention_heads": 32,
  "rope_type": "rope",
  "rope_dim_list": [16, 56, 56],
  "patch_size": [1, 2, 2],
  "theta": 10000
}

Findings (advisory — not blocking)

Under the published config and the pipeline's call path, the following code appears unreachable or unused:

1. Text RoPE branch is dead code (transformer_joyimage.py:561)

txt_rope_size=txt_seq_len if self.rope_type == "mrope" else None,

The published config has rope_type = "rope", so txt_rope_size is always None. This means:

  • get_rotary_pos_embed always takes the early return at line 490–491 (if txt_rope_size is None: return vis_freqs, None)
  • The entire text-RoPE computation (lines 494–501) is never executed
  • In JoyImageAttnProcessor.__call__, the text RoPE branch at line 143–144 (if txt_freqs is not None) is never taken — text tokens get no RoPE

This is likely intentional (the model was trained with rope_type="rope"), but the "mrope" code path is untested dead code.

2. rope_dim_list is None fallback is unreachable (transformer_joyimage.py:466-467)

if rope_dim_list is None:
    rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]

The published config sets rope_dim_list = [16, 56, 56], and the __init__ default is also [16, 56, 56]. There is no config where rope_dim_list would be None. This fallback is defensive but unreachable.

3. _keep_in_fp32_modules references non-existent modules (transformer_joyimage.py:381-387)

_keep_in_fp32_modules = [
    "time_embedder",       # ✓ exists (inside condition_embedder.time_embedder)
    "scale_shift_table",   # ✗ does NOT exist anywhere in the model
    "norm1",               # ✓ exists (img_norm1, txt_norm1 in blocks)
    "norm2",               # ✓ exists (img_norm2, txt_norm2 in blocks)
    "norm3",               # ✗ does NOT exist anywhere in the model
]

scale_shift_table and norm3 are WAN model concepts that were copied but don't exist in JoyImageEditTransformer3DModel. These entries silently match nothing — harmless but misleading.

4. num_items = 0 initialization is semantically misleading (transformer_joyimage.py:532)

num_items = 0
if is_multi_item:
    num_items = hidden_states.shape[1]

When is_multi_item is False, num_items remains 0. The multi-item reshape block at line 581 (if is_multi_item) won't execute in that case, so num_items = 0 is never used meaningfully. However, if is_multi_item were accidentally set while num_items = 0, it would cause a silent reshape bug. Consider leaving num_items unset until needed.

5. Pipeline: template_type="image" passed to encode_prompt is overridden when images are present (pipeline_joyimage_edit.py:757)

prompt_embeds, prompt_embeds_mask = self.encode_prompt(
    ...
    images=processed_image,
    ...
    template_type="image",   # <-- this value is ignored when images is not None
)

When images is not None, encode_prompt immediately delegates to encode_prompt_multiple_images (line 386-392), which hard-codes template_type="multiple_images". The template_type="image" argument passed from __call__ is silently discarded in the editing path.

6. _skip_layerwise_casting_patterns includes "condition_embedder" and a very broad "norm" pattern (transformer_joyimage.py:378)

_skip_layerwise_casting_patterns = ["img_in", "condition_embedder", "norm"]

The "norm" pattern matches norm_out, img_norm1, img_norm2, txt_norm1, txt_norm2, and also the RMSNorm QK norms inside attention (img_attn_q_norm, etc.). This is a very broad match. It also matches condition_embedder which contains the time embedder — this is likely intentional but means most of the model is excluded from layerwise casting.

7. vec.shape[-1] > self.hidden_size check in forward (transformer_joyimage.py:553-554)

if vec.shape[-1] > self.hidden_size:
    vec = vec.unflatten(1, (6, -1))

JoyImageTimeTextImageEmbedding.forward returns temb from time_proj which has output dim hidden_size * 6 (set at init: time_proj_dim=hidden_size * 6). So vec.shape[-1] is always hidden_size * 6, which is always > self.hidden_size. This branch is always taken — the else path (where vec stays 1D) is dead code. The check could be removed and replaced with an unconditional unflatten.
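
In other words, under this analysis the forward could collapse the branch to the unconditional form (sketch):

# `vec` always arrives as (B, 6 * hidden_size) under the published config,
# so the shape check is redundant and the unflatten can be unconditional
vec = vec.unflatten(1, (6, -1))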


Methods and layers that ARE used (confirming no false positives)

Component | Called from | Status
img_in (Conv3d) | forward:549 | ✓ Used
condition_embedder | forward:552 | ✓ Used
double_blocks | forward:565-574 | ✓ Used
norm_out + proj_out | forward:577 | ✓ Used
JoyImageModulate (img_mod, txt_mod) | JoyImageTransformerBlock.forward:295,302 | ✓ Used
JoyImageAttention (attn) | JoyImageTransformerBlock.forward:311 | ✓ Used
FeedForward (img_mlp, txt_mlp) | JoyImageTransformerBlock.forward:325-326 | ✓ Used
get_rotary_pos_embed | forward:559 | ✓ Used (vis path)
unpatchify | forward:578 | ✓ Used
_apply_rotary_emb | JoyImageAttnProcessor:142-144 | ✓ Used (vis path)

All initialized layers and helper methods are reachable from the pipeline call path. No dead layers or methods found.


Summary

The model is clean — all initialized layers and methods are used. The dead code is limited to the defensive or copied paths flagged above: the mrope text-RoPE branch, the rope_dim_list fallback, the stale _keep_in_fp32_modules entries, the always-true vec reshape check, and the ignored template_type argument.

None of these are bugs — they're defensive/copied code that happens to be unreachable under the published config. The authors may have configs or future models that exercise these paths.


@github-actions github-actions Bot added size/L PR with diff > 200 LOC and removed size/L PR with diff > 200 LOC labels May 2, 2026
… support prompt_embeds/prompt_embesd_mask/num_images_per_prompt in edit mode
feice-huang commented May 2, 2026

@dg845 @yiyixuxu
Thanks for the suggestions! Updates:

Refactor

  • Separate encode_prompt_multiple_images from encode_prompt. Dispatching is now handled in __call__ based on whether processed_image is provided, instead of branching inside encode_prompt.
  • Add prompt_embeds, prompt_embeds_mask, and num_images_per_prompt support to encode_prompt_multiple_images so it matches the encode_prompt interface.

Test updates

  • Skipped test_group_offloading_inference, test_pipeline_level_group_offloading_inference, test_sequential_cpu_offload_forward_pass, and test_sequential_offload_forward_pass_twice.

Not changed

  • We keep the current time_embedder dtype inference pattern (next(iter(...)).dtype). We tested simplifying it, but that causes dtype mismatch errors.



Development

Successfully merging this pull request may close these issues.

Support for JoyAI-Image-Edit

5 participants