
[feat] JoyAI-JoyImage-Edit support #13444

Open
Moran232 wants to merge 41 commits into huggingface:main from Moran232:joyimage_edit

Conversation


@Moran232 Moran232 commented Apr 10, 2026

Description

We are the JoyAI Team, and this is the Diffusers implementation for the JoyAI-Image-Edit model.

GitHub Repository: [https://github.com/jd-opensource/JoyAI-Image]
Hugging Face Model: [https://huggingface.co/jdopensource/JoyAI-Image-Edit-Diffusers]
Original open-source weights: [https://huggingface.co/jdopensource/JoyAI-Image-Edit]
Fixes #13430

Model Overview

JoyAI-Image is a unified multimodal foundation model for image understanding, text-to-image generation, and instruction-guided image editing. It combines an 8B Multimodal Large Language Model (MLLM) with a 16B Multimodal Diffusion Transformer (MMDiT).

Key Features

  • Advanced Text Rendering Showcase: JoyAI-Image is optimized for challenging text-heavy scenarios, including multi-panel comics, dense multi-line text, multilingual typography, long-form layouts, real-world scene text, and handwritten styles.
  • Multi-view Generation and Spatial Editing Showcase: JoyAI-Image showcases a spatially grounded generation and editing pipeline that supports multi-view generation, geometry-aware transformations, camera control, object rotation, and precise location-specific object editing. Across these settings, it preserves scene content, structure, and visual consistency while following viewpoint-sensitive instructions more accurately.
  • Spatial Editing for Spatial Reasoning Showcase: JoyAI-Image provides high-fidelity spatial editing, serving as a powerful catalyst for enhancing spatial reasoning. Compared with Qwen-Image-Edit and Nano Banana Pro, JoyAI-Image-Edit synthesizes the most diagnostic viewpoints by faithfully executing camera motions. These high-fidelity novel views effectively disambiguate complex spatial relations, providing clearer visual evidence for downstream reasoning.

Image edit examples

spatial-editing-showcase

@github-actions github-actions Bot added models pipelines size/L PR with diff > 200 LOC labels Apr 10, 2026
Collaborator

@yiyixuxu yiyixuxu left a comment


Thanks for the PR! I left some initial feedback.

Comment thread src/diffusers/models/transformers/transformer_joyimage.py Outdated
return x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))


class JoyImageEditTransformer3DModel(JoyImageTransformer3DModel):
Collaborator


ohh what's going on here? is this some legacy code? can we remove it?

Author


We first developed JoyImage and then trained JoyImage-Edit on top of it. This Transformer 3D model belongs to JoyImage, and JoyImage-Edit inherits from it. We will also open-source JoyImage in the future.

They essentially share similar Transformer 3D models. I understand that each pipeline requires a specific Transformer model, which is why we implemented inheritance in this way.
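To illustrate the pattern, here is a minimal sketch; everything beyond the two class names (e.g. the image_latents handling) is an illustrative assumption, not the PR's actual code:

import torch

from diffusers.models.transformers.transformer_joyimage import JoyImageTransformer3DModel


class JoyImageEditTransformer3DModel(JoyImageTransformer3DModel):
    """Reuses the JoyImage MMDiT backbone; only edit-specific inputs differ."""

    def forward(self, hidden_states, encoder_hidden_states, timestep, image_latents=None, **kwargs):
        # Hypothetical: append reference-image latents along the sequence axis
        # before running the shared JoyImage blocks.
        if image_latents is not None:
            hidden_states = torch.cat([hidden_states, image_latents], dim=1)
        return super().forward(hidden_states, encoder_hidden_states, timestep, **kwargs)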

Comment on lines +371 to +391
img_qkv = self.img_attn_qkv(img_modulated)
img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
img_q = self.img_attn_q_norm(img_q).to(img_v)
img_k = self.img_attn_k_norm(img_k).to(img_v)
if vis_freqs_cis is not None:
img_q, img_k = apply_rotary_emb(img_q, img_k, vis_freqs_cis, head_first=False)

txt_modulated = modulate(self.txt_norm1(txt), shift=txt_mod1_shift, scale=txt_mod1_scale)
txt_qkv = self.txt_attn_qkv(txt_modulated)
txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
txt_q = self.txt_attn_q_norm(txt_q).to(txt_v)
txt_k = self.txt_attn_k_norm(txt_k).to(txt_v)
if txt_freqs_cis is not None:
txt_q, txt_k = apply_rotary_emb(txt_q, txt_k, txt_freqs_cis, head_first=False)

q = torch.cat((img_q, txt_q), dim=1)
k = torch.cat((img_k, txt_k), dim=1)
v = torch.cat((img_v, txt_v), dim=1)

attn = attention(q, k, v, attn_kwargs=attn_kwargs).flatten(2, 3)
img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1] :]
Collaborator


Suggested change
img_qkv = self.img_attn_qkv(img_modulated)
img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
img_q = self.img_attn_q_norm(img_q).to(img_v)
img_k = self.img_attn_k_norm(img_k).to(img_v)
if vis_freqs_cis is not None:
img_q, img_k = apply_rotary_emb(img_q, img_k, vis_freqs_cis, head_first=False)
txt_modulated = modulate(self.txt_norm1(txt), shift=txt_mod1_shift, scale=txt_mod1_scale)
txt_qkv = self.txt_attn_qkv(txt_modulated)
txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
txt_q = self.txt_attn_q_norm(txt_q).to(txt_v)
txt_k = self.txt_attn_k_norm(txt_k).to(txt_v)
if txt_freqs_cis is not None:
txt_q, txt_k = apply_rotary_emb(txt_q, txt_k, txt_freqs_cis, head_first=False)
q = torch.cat((img_q, txt_q), dim=1)
k = torch.cat((img_k, txt_k), dim=1)
v = torch.cat((img_v, txt_v), dim=1)
attn = attention(q, k, v, attn_kwargs=attn_kwargs).flatten(2, 3)
img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1] :]
attn_output, text_attn_output = self.attn(...)

Can we refactor the attention implementation to follow diffusers style?
Basically, you need to move all the layers used in the attention calculation here into a JoyImageAttention (similar to FluxAttention: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_flux.py#L275)

Also create a JoyImageAttnProcessor (see FluxAttnProcessor as an example; I think it is the same): https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_flux.py#L75
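For reference, a hedged sketch of the module/processor split being requested, modeled on the Flux pattern (the internals are simplified assumptions for illustration; text-stream QK-norm and RoPE are omitted; this is not the PR's final code):

import torch
import torch.nn.functional as F
from torch import nn


class JoyImageAttnProcessor:
    """Stateless: all attention math lives here; parameters live on the module."""

    def __call__(self, attn, hidden_states, encoder_hidden_states):
        img_q, img_k, img_v = attn.img_qkv(hidden_states).chunk(3, dim=-1)
        txt_q, txt_k, txt_v = attn.txt_qkv(encoder_hidden_states).chunk(3, dim=-1)

        def split_heads(t):
            b, s, _ = t.shape
            return t.view(b, s, attn.heads, -1).transpose(1, 2)  # B H S D

        img_q, img_k, img_v, txt_q, txt_k, txt_v = map(
            split_heads, (img_q, img_k, img_v, txt_q, txt_k, txt_v)
        )
        img_q, img_k = attn.norm_q(img_q), attn.norm_k(img_k)

        # Joint attention over the concatenated image+text sequence.
        q, k, v = (torch.cat(t, dim=2) for t in ((img_q, txt_q), (img_k, txt_k), (img_v, txt_v)))
        out = F.scaled_dot_product_attention(q, k, v).transpose(1, 2).flatten(2)
        img_len = hidden_states.shape[1]
        return out[:, :img_len], out[:, img_len:]


class JoyImageAttention(nn.Module):
    """Owns the parameters; delegates computation to a swappable processor."""

    def __init__(self, dim, heads, head_dim, processor=None):
        super().__init__()
        self.heads = heads
        self.img_qkv = nn.Linear(dim, 3 * heads * head_dim, bias=True)
        self.txt_qkv = nn.Linear(dim, 3 * heads * head_dim, bias=True)
        self.norm_q = nn.RMSNorm(head_dim)
        self.norm_k = nn.RMSNorm(head_dim)
        self.processor = processor or JoyImageAttnProcessor()

    def forward(self, hidden_states, encoder_hidden_states):
        return self.processor(self, hidden_states, encoder_hidden_states)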

Author


Thanks for the reminder. I'll clean up this messy code.

Author


Fix in d397b68

Comment thread src/diffusers/models/transformers/transformer_joyimage.py Outdated
Comment on lines +242 to +250
class ModulateX(nn.Module):
def __init__(self, hidden_size: int, factor: int, dtype=None, device=None):
super().__init__()
self.factor = factor

def forward(self, x: torch.Tensor):
if len(x.shape) != 3:
x = x.unsqueeze(1)
return [o.squeeze(1) for o in x.chunk(self.factor, dim=1)]
Collaborator


Suggested change
class ModulateX(nn.Module):
def __init__(self, hidden_size: int, factor: int, dtype=None, device=None):
super().__init__()
self.factor = factor
def forward(self, x: torch.Tensor):
if len(x.shape) != 3:
x = x.unsqueeze(1)
return [o.squeeze(1) for o in x.chunk(self.factor, dim=1)]

Comment on lines +214 to +225
class ModulateDiT(nn.Module):
def __init__(self, hidden_size: int, factor: int, act_layer=nn.SiLU, dtype=None, device=None):
factory_kwargs = {"dtype": dtype, "device": device}
super().__init__()
self.factor = factor
self.act = act_layer()
self.linear = nn.Linear(hidden_size, factor * hidden_size, bias=True, **factory_kwargs)
nn.init.zeros_(self.linear.weight)
nn.init.zeros_(self.linear.bias)

def forward(self, x: torch.Tensor):
return self.linear(self.act(x)).chunk(self.factor, dim=-1)
Collaborator


Suggested change
class ModulateDiT(nn.Module):
def __init__(self, hidden_size: int, factor: int, act_layer=nn.SiLU, dtype=None, device=None):
factory_kwargs = {"dtype": dtype, "device": device}
super().__init__()
self.factor = factor
self.act = act_layer()
self.linear = nn.Linear(hidden_size, factor * hidden_size, bias=True, **factory_kwargs)
nn.init.zeros_(self.linear.weight)
nn.init.zeros_(self.linear.bias)
def forward(self, x: torch.Tensor):
return self.linear(self.act(x)).chunk(self.factor, dim=-1)

Is ModulateWan the one used in the model? If so, let's remove ModulateDiT and ModulateX.

Author


Fix in f557113

head_dim = hidden_size // heads_num
mlp_hidden_dim = int(hidden_size * mlp_width_ratio)

self.img_mod = load_modulation(self.dit_modulation_type, hidden_size, 6, **factory_kwargs)
Collaborator


Suggested change
self.img_mod = load_modulation(self.dit_modulation_type, hidden_size, 6, **factory_kwargs)
self.img_mod = JoyImageModulate(...)

Let's remove the load_modulation function and use the layer directly; better to rename it to JoyImageModulate too.

Author


Ok, I will refactor modulation and use ModulateWan

tacos8me added a commit to tacos8me/taco-desktop-backend that referenced this pull request Apr 11, 2026
Comment thread src/diffusers/models/transformers/transformer_joyimage.py Outdated
Comment thread src/diffusers/models/transformers/transformer_joyimage.py Outdated
Comment on lines +454 to +459
self.args = SimpleNamespace(
enable_activation_checkpointing=enable_activation_checkpointing,
is_repa=is_repa,
repa_layer=repa_layer,
)

Collaborator


Suggested change
self.args = SimpleNamespace(
enable_activation_checkpointing=enable_activation_checkpointing,
is_repa=is_repa,
repa_layer=repa_layer,
)

I think we can use self.config here (e.g. self.config.is_repa, self.config.repa_layer, etc.) instead of needing to define a separate namespace.
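For context: diffusers models whose __init__ is decorated with @register_to_config expose those arguments on self.config, so no extra namespace is needed. A minimal sketch (the toy class is illustrative; only the argument names come from the snippet above):

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin


class ToyTransformer(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(self, is_repa: bool = False, repa_layer: int = 8):
        super().__init__()
        # No SimpleNamespace needed: registered args are readable later
        # as self.config.is_repa and self.config.repa_layer.

    def forward(self, x):
        if self.config.is_repa:
            # repa-specific branch would go here
            pass
        return x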

Author

@Moran232 Moran232 Apr 14, 2026


I deleted the repa code; see f557113.

Collaborator


Was the repa logic removed because it is not used in inference?

Comment thread src/diffusers/pipelines/joyimage/pipeline_joyimage_edit.py Outdated
Comment on lines +900 to +901
timesteps: List[int] = None,
sigmas: List[float] = None,
Collaborator


Suggested change
timesteps: List[int] = None,
sigmas: List[float] = None,
timesteps: list[int] | None = None,
sigmas: list[float] | None = None,

nit: could we switch to Python 3.10+ style built-in type hints here and elsewhere?

Comment thread src/diffusers/pipelines/joyimage/pipeline_joyimage_edit.py Outdated
Comment thread src/diffusers/pipelines/joyimage/pipeline_joyimage_edit.py Outdated
Comment thread src/diffusers/pipelines/joyimage/pipeline_joyimage_edit.py Outdated
Comment thread src/diffusers/pipelines/joyimage/pipeline_joyimage_edit.py Outdated
Collaborator

@dg845 dg845 left a comment


Thanks for the PR! Left an initial design review :).

@github-actions github-actions Bot added size/L PR with diff > 200 LOC and removed size/L PR with diff > 200 LOC labels Apr 14, 2026
@Moran232
Author

@yiyixuxu @dg845
Thank you very much for your valuable feedback. I've made some modifications. See my latest commits.

Specifically, I refactored the attention module. However, since the weight key names in the Diffusers model are already fixed, I didn't change the actual keys in the attention part. Additionally, I will consider refactoring the image pre-processing logic; it is quite complex because I copied it directly over from the training code.

If you have any further suggestions, please feel free to share. Thank you so much!

# ---- joint attention (fused QKV, directly on the block) ----
# image attention layers
self.img_attn_qkv = nn.Linear(dim, inner_dim * 3, bias=True)
self.img_attn_q_norm = nn.RMSNorm(attention_head_dim, eps=eps)
Collaborator


If I remember correctly, the attention sublayer used to use a custom RMSNorm module, which upcast to FP32 during the RMS computation. Here we're using torch.nn.RMSNorm, which doesn't. Is this intentional?
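For reference, a minimal sketch of the upcasting behavior being asked about (an assumption modeled on the common LLaMA-style RMSNorm, not the PR's original module):

import torch
from torch import nn


class UpcastRMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        input_dtype = hidden_states.dtype
        # Compute the RMS statistic in FP32 for numerical stability ...
        variance = hidden_states.float().pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states.float() * torch.rsqrt(variance + self.eps)
        # ... then cast back to the activation dtype before scaling.
        return self.weight * hidden_states.to(input_dtype)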

Comment thread src/diffusers/models/transformers/transformer_joyimage.py Outdated
Comment thread src/diffusers/pipelines/joyimage/pipeline_joyimage_edit.py Outdated
Comment thread setup.py Outdated
@github-actions github-actions Bot added size/L PR with diff > 200 LOC and removed size/L PR with diff > 200 LOC labels Apr 29, 2026
@dg845
Collaborator

dg845 commented Apr 30, 2026

@bot /style

@github-actions
Contributor

Style fix is beginning .... View the workflow run here.

@github-actions github-actions Bot added size/L PR with diff > 200 LOC and removed size/L PR with diff > 200 LOC labels Apr 30, 2026
Comment thread src/diffusers/models/transformers/transformer_joyimage.py Outdated
if negative_prompt is None and negative_prompt_embeds is None:
if num_items <= 1:
negative_prompt = ["<|im_start|>user\n<|im_end|>\n"] * batch_size
else:
Collaborator


nit: my understanding is that wrapping the edit instructions (e.g. "Add wings to the astronaut.") with the Qwen3-VL template is important for sample quality, as seen in #13444 (comment). So I think it would be more user-friendly to automatically wrap the prompt with the template inside the pipeline like we do for negative_prompt here.
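A hedged sketch of such auto-wrapping (the helper name is hypothetical, and the template string is inferred from the negative_prompt default above rather than taken from the model card):

def _maybe_apply_chat_template(prompt: str) -> str:
    # Hypothetical helper: wrap bare edit instructions in the Qwen-style chat
    # template unless the caller already passed a templated prompt.
    if prompt.startswith("<|im_start|>"):
        return prompt
    return f"<|im_start|>user\n{prompt}<|im_end|>\n"


prompt = [_maybe_apply_chat_template(p) for p in ["Add wings to the astronaut."]]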



Fix in 87b5383


return prompt_embeds, prompt_embeds_mask

def encode_prompt(
Collaborator


suggestion: I think the way encode_prompt is currently implemented is confusing, as the code splits into two paths (the main encode_prompt logic and encode_prompt_multiple_images) which partially do the same thing. I think it would be clearer to refactor encode_prompt to something like this:

    def encode_prompt(...) -> tuple[torch.Tensor, torch.Tensor]:
        # 1. Handle inputs
        device = device or self._execution_device
        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]
        has_image_conditions = images is not None

        # 2. Generate prompt embeddings if necessary using Qwen3VL tokenizer/processor
        if prompt_embeds is None:
            template_type = "multiple_images" if has_image_conditions else "image"
            # _get_qwen_prompt_embeds is responsible for:
            #   1. Creating the final templated prompt
            #   2. Running the processor (or possibly tokenizer) to get the text encoder inputs
            #   3. Running the text encoder and getting the right Qwen3-VL hidden_states
            #   4. Any post-processing that's specific to the multiple or single image case
            prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, template_type, device)

        # 3. Post-process prompt_embeds (common logic for both cases)
        prompt_embeds = prompt_embeds[:, :max_sequence_length]
        prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length]

        # Handle expanding to num_images_per_prompt in both cases
        _, seq_len, _ = prompt_embeds.shape
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        ...
        return prompt_embeds, prompt_embeds_mask

I think ideally _get_qwen_prompt_embeds would handle both the "multiple_images" and "image" cases, but if they can't be combined cleanly we could have separate helpers for each case.



Thanks for the suggestion. The current implementation comes from our training code, and we have a multi-image editing model under active development that relies on this structure. We’d prefer to keep the current approach for now to avoid disrupting that work.

super().tearDown()

def get_dummy_components(self):
tiny_ckpt_id = "huangfeice/tiny-random-Qwen3VLForConditionalGeneration"
Collaborator


I think the current tiny Qwen3-VL testing checkpoint huangfeice/tiny-random-Qwen3VLForConditionalGeneration has the following issues:

  1. It's quite big (~5M params, ~19 MB), which makes the pipeline tests quite heavy, so I think we should try to reduce the size of this checkpoint. It looks like most of the parameters are in the input and output embeddings (e.g. embed_tokens), so for example reducing the vocab_size should help (see the sketch after this list).
  2. The checkpoint might be misconfigured: the model config defines a vision patch_size of 14, but the processor config defines an image_processor patch_size of 16. I think this mismatch is causing some tests such as test_cfg to fail.
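A hedged sketch of shrinking such a tiny checkpoint (the nested config attribute names are assumptions based on typical transformers VLM configs, not verified against Qwen3-VL):

from transformers import AutoConfig, AutoModelForImageTextToText

repo = "huangfeice/tiny-random-Qwen3VLForConditionalGeneration"
config = AutoConfig.from_pretrained(repo)
config.text_config.vocab_size = 1024   # assumption: most params sit in embed_tokens / lm_head
config.vision_config.patch_size = 16   # assumption: align with the processor's patch_size
model = AutoModelForImageTextToText.from_config(config)  # fresh random weights at the new sizes
model.save_pretrained("tiny-random-qwen3vl-small")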



The tiny-random-Qwen3VLForConditionalGeneration checkpoint has been updated on the Hugging Face Hub. Config fixed in 9d9ef52.

Comment thread tests/pipelines/joyimage/test_joyimage_edit.py Outdated
Collaborator

@dg845 dg845 left a comment


Thanks for the refactor! I think the PR is close to merge. I left a few small comments; the most important one is #13444 (comment), as this causes some pipeline tests to fail. Also, can you run make style and make quality to fix the code style, and make fix-copies to fix any dummy objects or out-of-sync copies?

CC @yiyixuxu to take a look at the forward hook in JoyImageEditPipeline._get_last_decoder_hidden_states for transformers>=5 compatibility.

@github-actions github-actions Bot added size/L PR with diff > 200 LOC and removed size/L PR with diff > 200 LOC labels Apr 30, 2026
@feice-huang

feice-huang commented Apr 30, 2026

@dg845 @yiyixuxu
Hello! I have completed a new round of revisions. There are two major updates:

  • Uploaded a smaller checkpoint to huangfeice/tiny-random-Qwen3VLForConditionalGeneration and updated the config in the tests. 9d9ef52

  • Refactored the pipeline: prompts are now automatically wrapped. Additionally, we enabled t2i (text-to-image) in the pipeline; the model now performs image generation when no image is provided. (JoyImageEdit is primarily developed for editing tasks, yet it also possesses inherent image generation capabilities.) 87b5383

Looking forward to the merge! Wishing you a Happy International Workers’ Day!

Here are some scripts:

Inference Scripts
import torch

from diffusers import JoyImageEditPipeline
from diffusers.utils import load_image

pipeline = JoyImageEditPipeline.from_pretrained(
    "jdopensource/JoyAI-Image-Edit-Diffusers", torch_dtype=torch.bfloat16
)
pipeline.to("cuda")

img_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
image = load_image(img_path)

# edit
image = pipeline(
    image=image,
    prompt="Add wings to the astronaut.",
    generator=torch.Generator("cuda").manual_seed(0),
    guidance_scale=4.0,
).images[0]

image.save("edit.png")

# t2i
image = pipeline(
    prompt="A toy astronaut with wings.",
    generator=torch.Generator("cuda").manual_seed(0),
    guidance_scale=4.0,
).images[0]

image.save("t2i.png")

# batch edit
output = pipeline(
    image=[image, image],
    prompt=["Add wings to the astronaut.", "Add halo to the astronaut."],
    generator=torch.Generator("cuda").manual_seed(0),
    guidance_scale=4.0,
)

output.images[0].save("result1.png")
output.images[1].save("result2.png")

@dg845
Collaborator

dg845 commented Apr 30, 2026

@bot /style

@github-actions
Contributor

Style fix is beginning .... View the workflow run here.

@github-actions github-actions Bot added size/L PR with diff > 200 LOC and removed size/L PR with diff > 200 LOC labels Apr 30, 2026
@dg845
Collaborator

dg845 commented May 1, 2026

@bot /style

@github-actions
Contributor

github-actions Bot commented May 1, 2026

Style fix is beginning .... View the workflow run here.

@dg845
Collaborator

dg845 commented May 1, 2026

When I run the JoyImageEditPipeline tests locally with

pytest tests/pipelines/joyimage/test_joyimage_edit.py

the following tests fail:

  • test_group_offloading_inference
  • test_pipeline_level_group_offloading_inference
  • test_sequential_cpu_offload_forward_pass
  • test_sequential_offload_forward_pass_twice

After debugging, my understanding is that the root cause of these test failures is that Qwen3VLForConditionalGeneration doesn't support leaf-level offloading (specifically, Qwen3VLVisionModel.fast_pos_embed_interpolate breaks the leaf-level offloading assumption by caching the device before the pre_forward offloading hook fires, then using it for device placement afterwards, so the pos_embeds activation it produces ends up on the cached offload device rather than the onload device.)

Block-level offloading does work for Qwen3VLForConditionalGeneration because the Qwen3VLModel submodule is treated as a single atomic group, which avoids the "device read and placement between groups" device-mismatch issue.
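A minimal illustrative sketch of the failure pattern described above (names and shapes are assumptions; this is not the actual Qwen3-VL code):

import torch
from torch import nn


class LeafOffloadUnfriendly(nn.Module):
    def __init__(self):
        super().__init__()
        self.pos_embed = nn.Embedding(16, 8)  # leaf module: offloaded to CPU between calls

    def forward(self, num_positions: int) -> torch.Tensor:
        # Bug pattern: the device is read *before* the Embedding's pre_forward
        # hook onloads the weights, so this index tensor is created on the
        # offload device (CPU) ...
        device = self.pos_embed.weight.device
        idx = torch.arange(num_positions, device=device)
        # ... and when the hook later onloads the weights to the accelerator,
        # the lookup mixes devices and the resulting activation ends up on the
        # wrong one.
        return self.pos_embed(idx)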

For now, I think we can either skip the tests or override them so that the Qwen3-VL text_encoder is excluded when testing leaf-level group offloading.

Claude suggestion for overriding test_group_offloading_inference
    def test_group_offloading_inference(self):
        # Qwen3VLForConditionalGeneration (the text encoder) is incompatible with leaf_level group
        # offloading. Its Qwen3VLVisionModel.fast_pos_embed_interpolate reads
        # `self.pos_embed.weight.device` to create intermediate tensors before the Embedding's
        # pre_forward hook fires, so the intermediate tensors land on CPU while hidden_states
        # (produced by the Conv3d patch_embed) land on CUDA, causing a device mismatch.
        #
        # block_level works correctly: since Qwen3VLForConditionalGeneration has no ModuleList as a
        # direct child, the entire model forms one unmatched group that onloads atomically before any
        # submodule code runs, so pos_embed.weight.device is CUDA by the time it is read.
        #
        # For leaf_level we therefore move the text encoder to the target device directly (the same
        # pattern the base test already uses for the VAE) and only apply leaf_level offloading to
        # the diffusers-native transformer.
        if not self.test_group_offloading:
            return

        def create_pipe():
            torch.manual_seed(0)
            components = self.get_dummy_components()
            pipe = self.pipeline_class(**components)
            pipe.set_progress_bar_config(disable=None)
            return pipe

        def run_forward(pipe):
            torch.manual_seed(0)
            inputs = self.get_dummy_inputs(torch_device)
            return pipe(**inputs)[0]

        pipe = create_pipe().to(torch_device)
        output_without_group_offloading = run_forward(pipe)

        # block_level: the full text encoder becomes one group (no direct ModuleList children), so
        # the atomic onload/offload is safe.
        pipe = create_pipe()
        for component_name in ["transformer", "text_encoder"]:
            component = getattr(pipe, component_name, None)
            if component is None:
                continue
            if hasattr(component, "enable_group_offload"):
                component.enable_group_offload(
                    torch.device(torch_device), offload_type="block_level", num_blocks_per_group=1
                )
            else:
                apply_group_offloading(
                    component,
                    onload_device=torch.device(torch_device),
                    offload_type="block_level",
                    num_blocks_per_group=1,
                )
        pipe.vae.to(torch_device)
        output_with_block_level = run_forward(pipe)

        # leaf_level: skip the text encoder (transformers model with device-dependent tensor
        # creation) and move it to the target device directly.
        pipe = create_pipe()
        pipe.transformer.enable_group_offload(
            torch.device(torch_device), offload_type="leaf_level"
        )
        pipe.text_encoder.to(torch_device)
        pipe.vae.to(torch_device)
        output_with_leaf_level = run_forward(pipe)

        if torch.is_tensor(output_without_group_offloading):
            output_without_group_offloading = output_without_group_offloading.detach().cpu().numpy()
            output_with_block_level = output_with_block_level.detach().cpu().numpy()
            output_with_leaf_level = output_with_leaf_level.detach().cpu().numpy()

        self.assertTrue(np.allclose(output_without_group_offloading, output_with_block_level, atol=1e-4))
        self.assertTrue(np.allclose(output_without_group_offloading, output_with_leaf_level, atol=1e-4))

In the longer term (a separate PR), we can consider adding support for specifying modules that only support block-level offloading in e.g. DiffusionPipeline.enable_group_offload/apply_group_offloading, for example with something like a per_module_offload_type dict argument. CC @sayakpaul
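For illustration, usage of such a purely hypothetical argument (not in diffusers today) might look like:

import torch

# Hypothetical future API sketch, assuming `pipe` is a loaded DiffusionPipeline;
# per_module_offload_type does not currently exist.
pipe.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_type="leaf_level",
    per_module_offload_type={"text_encoder": "block_level"},  # hypothetical kwarg
)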

@dg845
Collaborator

dg845 commented May 1, 2026

Also, can you run make style and make quality to fix the code style so that the CI is green? The style bot doesn't seem to be working so I don't think I can do it from my side.

# Internal helpers
# ------------------------------------------------------------------

def _get_last_decoder_hidden_states(self, forward_fn, **kwargs):
Collaborator


ohhhh
we will double-check with the transformers team

