[feat] JoyAI-JoyImage-Edit support #13444
Conversation
yiyixuxu
left a comment
thanks for the PR! I left some initial feedback
return x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))

class JoyImageEditTransformer3DModel(JoyImageTransformer3DModel):
ohh what's going on here? is this some legacy code? can we remove?
We first developed JoyImage and then trained JoyImage-Edit on top of it. This Transformer 3D model belongs to JoyImage, and JoyImage-Edit inherits from it. We will also open-source JoyImage in the future.
They essentially share similar Transformer 3D models. I understand that each pipeline requires a specific Transformer model, which is why we implemented inheritance in this way.
img_qkv = self.img_attn_qkv(img_modulated)
img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
img_q = self.img_attn_q_norm(img_q).to(img_v)
img_k = self.img_attn_k_norm(img_k).to(img_v)
if vis_freqs_cis is not None:
    img_q, img_k = apply_rotary_emb(img_q, img_k, vis_freqs_cis, head_first=False)

txt_modulated = modulate(self.txt_norm1(txt), shift=txt_mod1_shift, scale=txt_mod1_scale)
txt_qkv = self.txt_attn_qkv(txt_modulated)
txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
txt_q = self.txt_attn_q_norm(txt_q).to(txt_v)
txt_k = self.txt_attn_k_norm(txt_k).to(txt_v)
if txt_freqs_cis is not None:
    txt_q, txt_k = apply_rotary_emb(txt_q, txt_k, txt_freqs_cis, head_first=False)

q = torch.cat((img_q, txt_q), dim=1)
k = torch.cat((img_k, txt_k), dim=1)
v = torch.cat((img_v, txt_v), dim=1)

attn = attention(q, k, v, attn_kwargs=attn_kwargs).flatten(2, 3)
img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1] :]
Suggested change: replace the inline QKV/attention computation quoted above with a single call:
attn_output, text_attn_output = self.attn(...)
can we refactor the attention implementation to follow diffusers style?
basically you need to move all the layers used in attention calculation here into a JoyImageAttention (similar to FluxAttention https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_flux.py#L275)
also create a JoyImageAttnProcessor (see FluxAttnProcessor as an example, I think it is the same): https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_flux.py#L75
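A minimal sketch of what that split could look like (class names follow the suggestion; the constructor arguments, the processor wiring, and the omission of RoPE are assumptions for illustration, not the final implementation):

# Illustrative sketch only — not the PR's actual code.
import torch
import torch.nn as nn
import torch.nn.functional as F


class JoyImageAttnProcessor:
    def __call__(self, attn, hidden_states, encoder_hidden_states):
        img_len = hidden_states.shape[1]

        def project(qkv_proj, q_norm, k_norm, x):
            # (B, L, 3 * H * D) -> three tensors of shape (B, L, H, D)
            q, k, v = qkv_proj(x).unflatten(-1, (3, attn.heads, -1)).unbind(dim=-3)
            return q_norm(q).to(v), k_norm(k).to(v), v

        img_q, img_k, img_v = project(attn.img_attn_qkv, attn.img_attn_q_norm, attn.img_attn_k_norm, hidden_states)
        txt_q, txt_k, txt_v = project(attn.txt_attn_qkv, attn.txt_attn_q_norm, attn.txt_attn_k_norm, encoder_hidden_states)
        # (rotary embeddings would be applied to img_q/img_k and txt_q/txt_k here)

        q = torch.cat((img_q, txt_q), dim=1).transpose(1, 2)  # (B, H, L, D)
        k = torch.cat((img_k, txt_k), dim=1).transpose(1, 2)
        v = torch.cat((img_v, txt_v), dim=1).transpose(1, 2)
        out = F.scaled_dot_product_attention(q, k, v).transpose(1, 2).flatten(2)
        return out[:, :img_len], out[:, img_len:]


class JoyImageAttention(nn.Module):
    def __init__(self, dim: int, heads: int, head_dim: int, eps: float = 1e-6):
        super().__init__()
        self.heads = heads
        self.img_attn_qkv = nn.Linear(dim, heads * head_dim * 3, bias=True)
        self.txt_attn_qkv = nn.Linear(dim, heads * head_dim * 3, bias=True)
        self.img_attn_q_norm = nn.RMSNorm(head_dim, eps=eps)
        self.img_attn_k_norm = nn.RMSNorm(head_dim, eps=eps)
        self.txt_attn_q_norm = nn.RMSNorm(head_dim, eps=eps)
        self.txt_attn_k_norm = nn.RMSNorm(head_dim, eps=eps)
        self.processor = JoyImageAttnProcessor()

    def forward(self, hidden_states, encoder_hidden_states):
        return self.processor(self, hidden_states, encoder_hidden_states)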
Thanks for the reminder. I'll clean up this messy code.
class ModulateX(nn.Module):
    def __init__(self, hidden_size: int, factor: int, dtype=None, device=None):
        super().__init__()
        self.factor = factor

    def forward(self, x: torch.Tensor):
        if len(x.shape) != 3:
            x = x.unsqueeze(1)
        return [o.squeeze(1) for o in x.chunk(self.factor, dim=1)]
class ModulateDiT(nn.Module):
    def __init__(self, hidden_size: int, factor: int, act_layer=nn.SiLU, dtype=None, device=None):
        factory_kwargs = {"dtype": dtype, "device": device}
        super().__init__()
        self.factor = factor
        self.act = act_layer()
        self.linear = nn.Linear(hidden_size, factor * hidden_size, bias=True, **factory_kwargs)
        nn.init.zeros_(self.linear.weight)
        nn.init.zeros_(self.linear.bias)

    def forward(self, x: torch.Tensor):
        return self.linear(self.act(x)).chunk(self.factor, dim=-1)
is ModulateWan the one used in the model? If so, let's remove ModulateDiT and ModulateX
head_dim = hidden_size // heads_num
mlp_hidden_dim = int(hidden_size * mlp_width_ratio)

self.img_mod = load_modulation(self.dit_modulation_type, hidden_size, 6, **factory_kwargs)
Suggested change:
self.img_mod = JoyImageModulate(...)
let's remove the load_modulation function and use the layer directly, better to rename to JoyImageModulate too
Ok, I will refactor modulation and use ModulateWan
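For reference, a minimal sketch of what the renamed layer could look like, assuming it keeps the SiLU + zero-initialized Linear behaviour of the ModulateDiT quoted above (the JoyImageModulate name comes from the review suggestion; the default factor is illustrative):

import torch
import torch.nn as nn


class JoyImageModulate(nn.Module):
    def __init__(self, hidden_size: int, factor: int = 6):
        super().__init__()
        self.factor = factor
        self.act = nn.SiLU()
        self.linear = nn.Linear(hidden_size, factor * hidden_size, bias=True)
        nn.init.zeros_(self.linear.weight)
        nn.init.zeros_(self.linear.bias)

    def forward(self, temb: torch.Tensor):
        # returns `factor` chunks (shift/scale/gate groups) from the conditioning embedding
        return self.linear(self.act(temb)).chunk(self.factor, dim=-1)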
self.args = SimpleNamespace(
    enable_activation_checkpointing=enable_activation_checkpointing,
    is_repa=is_repa,
    repa_layer=repa_layer,
)
I think we can use self.config here (e.g. self.config.is_repa, self.config.repa_layer, etc.) instead of needing to define a separate namespace.
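As a rough illustration of that pattern (the class name, defaults, and forward body are made up for the example; only the register_to_config decorator is the real diffusers API):

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin


class ExampleTransformer(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(self, is_repa: bool = False, repa_layer: int = 8, enable_activation_checkpointing: bool = False):
        super().__init__()

    def forward(self, hidden_states):
        # every __init__ argument is available as self.config.<name>, so no SimpleNamespace is needed
        if self.config.is_repa:
            _ = self.config.repa_layer
        return hidden_states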
Was the repa logic removed because it is not used in inference?
timesteps: List[int] = None,
sigmas: List[float] = None,
Suggested change:
timesteps: list[int] | None = None,
sigmas: list[float] | None = None,
nit: could we switch to Python 3.9+ style implicit type hints here and elsewhere?
dg845
left a comment
Thanks for the PR! Left an initial design review :).
@yiyixuxu @dg845 Specifically, I refactored the attention module. However, since the weight key names in the Diffusers model are already fixed, I didn't change the actual keys in the attention part. Additionally, I will consider refactoring the image pre-processing logic; since the logic is quite complex, I directly copied it over from the training code. If you have any further suggestions, please feel free to share. Thank you so much!
# ---- joint attention (fused QKV, directly on the block) ----
# image attention layers
self.img_attn_qkv = nn.Linear(dim, inner_dim * 3, bias=True)
self.img_attn_q_norm = nn.RMSNorm(attention_head_dim, eps=eps)
If I remember correctly, the attention sublayer used to use the custom RMSNorm module, which upcasted to FP32 during the RMS computation. Here we're using torch.nn.RMSNorm, which doesn't. Is this intentional?
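For reference, a sketch of the upcasting behaviour being described — not the original module, just an illustration of computing the RMS statistics in float32 and casting the result back to the input dtype:

import torch
import torch.nn as nn


class FP32RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        orig_dtype = x.dtype
        x = x.float()
        x = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return (x * self.weight.float()).to(orig_dtype)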
@dg845 @yiyixuxu
Looking forward to the merge! Wishing you a Happy International Workers' Day! Here are some scripts:
Inference Scripts
import torch
from diffusers import JoyImageEditPipeline
from diffusers.utils import load_image
pipeline = JoyImageEditPipeline.from_pretrained(
"jdopensource/JoyAI-Image-Edit-Diffusers", torch_dtype=torch.bfloat16
)
pipeline.to("cuda")
img_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
image = load_image(img_path)
# edit
image = pipeline(
image=image,
prompt="Add wings to the astronaut.",
generator=torch.Generator("cuda").manual_seed(0),
guidance_scale=4.0,
).images[0]
image.save("edit.png")
# t2i
image = pipeline(
prompt="A toy astronaut with wings.",
generator=torch.Generator("cuda").manual_seed(0),
guidance_scale=4.0,
).images[0]
image.save("t2i.png")
# batch edit
output = pipeline(
image=[image, image],
prompt=["Add wings to the astronaut.", "Add halo to the astronaut."],
generator=torch.Generator("cuda").manual_seed(0),
guidance_scale=4.0,
)
output.images[0].save("result1.png")
output.images[1].save("result2.png")
@bot /style
Style fix is beginning .... View the workflow run here.
@bot /style
Style fix is beginning .... View the workflow run here.
When I run pytest tests/pipelines/joyimage/test_joyimage_edit.py the following tests fail:
After debugging, my understanding is that the root cause of these test failures is that block-level offloading does work for the Qwen3-VL text encoder but leaf-level offloading does not. For now, I think we can either skip the tests or override them so that the Qwen3-VL text encoder is excluded from leaf-level offloading.
Claude suggestion for overriding test_group_offloading_inference
def test_group_offloading_inference(self):
# Qwen3VLForConditionalGeneration (the text encoder) is incompatible with leaf_level group
# offloading. Its Qwen3VLVisionModel.fast_pos_embed_interpolate reads
# `self.pos_embed.weight.device` to create intermediate tensors before the Embedding's
# pre_forward hook fires, so the intermediate tensors land on CPU while hidden_states
# (produced by the Conv3d patch_embed) land on CUDA, causing a device mismatch.
#
# block_level works correctly: since Qwen3VLForConditionalGeneration has no ModuleList as a
# direct child, the entire model forms one unmatched group that onloads atomically before any
# submodule code runs, so pos_embed.weight.device is CUDA by the time it is read.
#
# For leaf_level we therefore move the text encoder to the target device directly (the same
# pattern the base test already uses for the VAE) and only apply leaf_level offloading to
# the diffusers-native transformer.
if not self.test_group_offloading:
return
def create_pipe():
torch.manual_seed(0)
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.set_progress_bar_config(disable=None)
return pipe
def run_forward(pipe):
torch.manual_seed(0)
inputs = self.get_dummy_inputs(torch_device)
return pipe(**inputs)[0]
pipe = create_pipe().to(torch_device)
output_without_group_offloading = run_forward(pipe)
# block_level: the full text encoder becomes one group (no direct ModuleList children), so
# the atomic onload/offload is safe.
pipe = create_pipe()
for component_name in ["transformer", "text_encoder"]:
component = getattr(pipe, component_name, None)
if component is None:
continue
if hasattr(component, "enable_group_offload"):
component.enable_group_offload(
torch.device(torch_device), offload_type="block_level", num_blocks_per_group=1
)
else:
apply_group_offloading(
component,
onload_device=torch.device(torch_device),
offload_type="block_level",
num_blocks_per_group=1,
)
pipe.vae.to(torch_device)
output_with_block_level = run_forward(pipe)
# leaf_level: skip the text encoder (transformers model with device-dependent tensor
# creation) and move it to the target device directly.
pipe = create_pipe()
pipe.transformer.enable_group_offload(
torch.device(torch_device), offload_type="leaf_level"
)
pipe.text_encoder.to(torch_device)
pipe.vae.to(torch_device)
output_with_leaf_level = run_forward(pipe)
if torch.is_tensor(output_without_group_offloading):
output_without_group_offloading = output_without_group_offloading.detach().cpu().numpy()
output_with_block_level = output_with_block_level.detach().cpu().numpy()
output_with_leaf_level = output_with_leaf_level.detach().cpu().numpy()
self.assertTrue(np.allclose(output_without_group_offloading, output_with_block_level, atol=1e-4))
self.assertTrue(np.allclose(output_without_group_offloading, output_with_leaf_level, atol=1e-4))
In the longer term (a separate PR), we can consider adding support for specifying modules that only support block-level offloading in e.g.
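One possible shape for that longer-term idea, purely as an illustration (the `_exclude_from_leaf_level_offload` attribute and this helper are hypothetical, not an existing diffusers API; components named there would fall back to block-level offloading, mirroring the test override above):

import torch
from diffusers import DiffusionPipeline


def enable_group_offload_with_exclusions(pipe: DiffusionPipeline, device: torch.device, offload_type: str = "leaf_level"):
    # Hypothetical: the pipeline class would declare _exclude_from_leaf_level_offload = ["text_encoder"]
    excluded = set(getattr(pipe, "_exclude_from_leaf_level_offload", ()))
    for name, component in pipe.components.items():
        if component is None or not hasattr(component, "enable_group_offload"):
            continue
        if offload_type == "leaf_level" and name in excluded:
            component.enable_group_offload(device, offload_type="block_level", num_blocks_per_group=1)
        else:
            component.enable_group_offload(device, offload_type=offload_type)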
Also, can you run
yiyixuxu
left a comment
thanks!!
I left some final comments! Let's merge this soon :)
time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
    timestep = timestep.to(time_embedder_dtype)
temb = self.time_embedder(timestep).type_as(encoder_hidden_states)
Suggested change:
timestep = timestep.to(encoder_hidden_states.dtype)
temb = self.time_embedder(timestep)
Would this work? I know this function is copied from Wan, but this is a pattern we generally want to avoid (inferring dtype from weights), and I cannot recall why we did it this way with Wan (could just be an oversight)
let me know if we can simplify here a bit
I tried the simplification but it causes a RuntimeError: mat1 and mat2 must have the same dtype, but got BFloat16 and Float at runtime.
This is because condition_embedder is listed in _skip_layerwise_casting_patterns, so its weights stay in Float32 during layerwise casting, while encoder_hidden_states is BFloat16. Casting timestep to encoder_hidden_states.dtype (bf16) then feeding it into the Float32 time_embedder linear layers triggers the mismatch.
The current pattern reads the actual weight dtype, aligns the input to it, computes in that precision, then casts the output back. This is necessary, and I think this is also why WAN does it the same way.
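Condensed, the pattern being kept looks like this (an illustrative free function, not the exact transformer code): align the input to the fp32 time-embedder weights, compute in that precision, then cast back to the activation dtype.

import torch


def embed_timestep(time_embedder, timestep, encoder_hidden_states):
    weight_dtype = next(iter(time_embedder.parameters())).dtype
    if timestep.dtype != weight_dtype and weight_dtype != torch.int8:
        timestep = timestep.to(weight_dtype)  # compute the embedding in the weight precision
    return time_embedder(timestep).type_as(encoder_hidden_states)  # cast back to activation dtype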
Error message
❯ python solve.py
Skipping import of cpp extensions due to incompatible torch version 2.8.0+cu128 for torchao version 0.16.0 Please see https://github.com/pytorch/ao/issues/2919 for more info
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:01<00:00, 4.84it/s]
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:01<00:00, 3.06it/s]
Loading pipeline components...: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:04<00:00, 1.34it/s]
0%| | 0/40 [00:00<?, ?it/s]
Traceback (most recent call last):
File "/pfs/huangfeice/solve/solve.py", line 18, in <module>
image = pipeline(
File "/root/miniforge/envs/comfyenv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
return func(*args, **kwargs)
File "/pfs/huangfeice/diffusers/src/diffusers/pipelines/joyimage/pipeline_joyimage_edit.py", line 845, in __call__
noise_pred = self.transformer(
File "/root/miniforge/envs/comfyenv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/root/miniforge/envs/comfyenv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
File "/pfs/huangfeice/diffusers/src/diffusers/models/transformers/transformer_joyimage.py", line 550, in forward
_, vec, txt = self.condition_embedder(timestep, encoder_hidden_states)
File "/root/miniforge/envs/comfyenv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/root/miniforge/envs/comfyenv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
File "/pfs/huangfeice/diffusers/src/diffusers/models/transformers/transformer_joyimage.py", line 357, in forward
temb = self.time_embedder(timestep)
File "/root/miniforge/envs/comfyenv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/root/miniforge/envs/comfyenv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
File "/pfs/huangfeice/diffusers/src/diffusers/models/embeddings.py", line 1297, in forward
sample = self.linear_1(sample)
File "/root/miniforge/envs/comfyenv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/root/miniforge/envs/comfyenv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
File "/root/miniforge/envs/comfyenv/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 125, in forward
return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 must have the same dtype, but got BFloat16 and Float

if images is not None:
    return self.encode_prompt_multiple_images(
        prompt=prompt,
        images=images,
        device=device,
        max_sequence_length=max_sequence_length,
    )
can we move the encode_prompt_multiple_images out from encode_prompt?
so in pipeline.__call__:
if processed_image is None:
    ... = self.encode_prompt()
else:
    ... = self.encode_prompt_multiple_images()
Additionally, let's make sure encode_prompt_multiple_images also accepts pre-computed prompt_embeds/prompt_embeds_mask and expands based on num_images_per_prompt - similar to what's done in encode_prompt.
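A rough sketch of the requested signature (the `_get_image_text_embeds` helper name is hypothetical and stands in for the pipeline's existing image+text encoding step; the expansion mirrors what encode_prompt already does):

from typing import List, Optional

import torch


def encode_prompt_multiple_images(
    self,
    prompt: List[str],
    images,
    device,
    num_images_per_prompt: int = 1,
    prompt_embeds: Optional[torch.Tensor] = None,
    prompt_embeds_mask: Optional[torch.Tensor] = None,
    max_sequence_length: int = 4096,
):
    # Only run the text encoder when embeddings were not supplied by the caller.
    if prompt_embeds is None:
        prompt_embeds, prompt_embeds_mask = self._get_image_text_embeds(
            prompt, images, device=device, max_sequence_length=max_sequence_length
        )
    # Expand each prompt num_images_per_prompt times, mirroring encode_prompt.
    prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
    prompt_embeds_mask = prompt_embeds_mask.repeat_interleave(num_images_per_prompt, dim=0)
    return prompt_embeds, prompt_embeds_mask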
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 4096,
enable_denormalization: bool = True,
**kwargs,
ohh what's this for? can we remove this?
@claude can you do a final review too?
Claude finished @yiyixuxu's task in 3m 24s
Final Review of JoyAI-JoyImage-Edit PR
The PR has improved significantly since the last round. Many of the earlier issues (einops removal, attention refactor, gradient checkpointing, assert→ValueError, etc.) have been addressed. Below are the remaining issues I found, organized by severity. I've also noted which items align with outstanding reviewer comments from @yiyixuxu and @dg845 that still need resolution.
Must Fix
1. _keep_in_fp32_modules = [
       "time_embedder",
       "scale_shift_table",  # ← does not exist in this model
       "norm1",
       "norm2",
       "norm3",  # ← does not exist in this model
   ]
2. "JoyImageEditPipeline",
   "Kandinsky3Img2ImgPipeline",
   "JoyImageEditPipelineOutput",  # ← out of alphabetical order
3. enable_tiling (`bool`, *optional*, defaults to `False`):
       Enable tiled VAE decoding to reduce peak memory usage.
   This parameter does not exist in
4. yiyixuxu flagged this.
Should Fix (Outstanding reviewer comments)
5. Currently
   # In __call__:
   if processed_image is None:
       ... = self.encode_prompt(...)
   else:
       ... = self.encode_prompt_multiple_images(...)
   Additionally,
6. Timestep dtype inference via weight introspection —
   time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
   if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
       timestep = timestep.to(time_embedder_dtype)
   Per models.md gotcha #4: "don't cast activations by reading a weight's dtype — the stored weight dtype isn't the compute dtype under gguf / quantized loading." yiyixuxu suggested simplifying to:
   timestep = timestep.to(encoder_hidden_states.dtype)
   temb = self.time_embedder(timestep)
7. The attention sublayer QK norms use
Nice to Fix
8. Two forward passes instead of batched CFG — The denoising loop makes two separate transformer forward passes (conditional + unconditional) instead of batching them together. This is ~2x slower than the standard diffusers pattern:
   latent_model_input = torch.cat([latents] * 2)
   prompt_concat = torch.cat([negative_prompt_embeds, prompt_embeds])
   noise_pred = self.transformer(latent_model_input, ..., prompt_concat, ...)
   noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
   I understand this may be intentional due to the multi-item latent structure, but it's worth considering for performance.
9. Always-on guidance rescaling —
   cond_norm = torch.norm(noise_pred, dim=2, keepdim=True)
   noise_norm = torch.norm(comb_pred, dim=2, keepdim=True)
   noise_pred = comb_pred * (cond_norm / noise_norm.clamp_min(1e-6))
   This norm-based guidance rescaling is always applied during CFG. Most diffusers pipelines make this optional via a
10. Mixed
    The pipeline mixes old-style typing (
11.
12.
    if basesize not in BUCKETS:
        raise ValueError(f"Unsupported basesize {basesize}. Supported: {list(BUCKETS.keys())}")
13.
Style / CI
14.
Tests
15. pytest.skip("Tolerance requirements too high for meaningful test")
    This unconditionally skips a core test. If there are tolerance issues, investigate the root cause or use
16. Offloading test failures — dg845 identified that
Summary
The architecture is clean —
The main outstanding items are yiyixuxu's latest review comments (#5, #6, #4), the stale
ohh Claude did a good job here, can you look into fixing the
@claude are you able to do a dead code analysis since it's a new model here?
Claude finished @yiyixuxu's task in 2m 9s
Dead Code Analysis for JoyImage-Edit Model
Why this wasn't in the previous reviews
The review rules in
Call path trace
Pipeline:
noise_pred = self.transformer(
    hidden_states=latent_model_input,  # (B, num_items, C, T, H, W) — 6D when editing
    timestep=t_expand,
    encoder_hidden_states=prompt_embeds,
    return_dict=False,
)[0]
Only three arguments are ever passed:
def forward(self, hidden_states, timestep, encoder_hidden_states=None, return_dict=True):
All four parameters are used. Good — no unused forward parameters.
Published config (jdopensource/JoyAI-Image-Edit-Diffusers/transformer/config.json):
{
    "hidden_size": 4096,
    "num_attention_heads": 32,
    "rope_type": "rope",
    "rope_dim_list": [16, 56, 56],
    "patch_size": [1, 2, 2],
    "theta": 10000
}
Findings (advisory — not blocking)
Under the published config and the pipeline's call path, the following code appears unreachable or unused:
1. Text RoPE branch is dead code —
   txt_rope_size=txt_seq_len if self.rope_type == "mrope" else None,
   The published config has
   This is likely intentional (the model was trained with
2. if rope_dim_list is None:
       rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
   The published config sets
3. _keep_in_fp32_modules = [
       "time_embedder",      # ✓ exists (inside condition_embedder.time_embedder)
       "scale_shift_table",  # ✗ does NOT exist anywhere in the model
       "norm1",              # ✓ exists (img_norm1, txt_norm1 in blocks)
       "norm2",              # ✓ exists (img_norm2, txt_norm2 in blocks)
       "norm3",              # ✗ does NOT exist anywhere in the model
   ]
4. num_items = 0
   if is_multi_item:
       num_items = hidden_states.shape[1]
   When
5. Pipeline:
   prompt_embeds, prompt_embeds_mask = self.encode_prompt(
       ...
       images=processed_image,
       ...
       template_type="image",  # <-- this value is ignored when images is not None
   )
   When
6. _skip_layerwise_casting_patterns = ["img_in", "condition_embedder", "norm"]
   The
7. if vec.shape[-1] > self.hidden_size:
       vec = vec.unflatten(1, (6, -1))
Methods and layers that ARE used (confirming no false positives)
All initialized layers and helper methods are reachable from the pipeline call path. No dead layers or methods found.
Summary
The model is clean — all initialized layers and methods are used. The dead code is limited to:
None of these are bugs — they're defensive/copied code that happens to be unreachable under the published config. The authors may have configs or future models that exercise these paths.
… support prompt_embeds/prompt_embeds_mask/num_images_per_prompt in edit mode
|
@dg845 @yiyixuxu
Refactor
Test updates
Not changed
Description
We are the JoyAI Team, and this is the Diffusers implementation for the JoyAI-Image-Edit model.
GitHub Repository: [https://github.com/jd-opensource/JoyAI-Image]
Hugging Face Model: [https://huggingface.co/jdopensource/JoyAI-Image-Edit-Diffusers]
Original open-source weights: [https://huggingface.co/jdopensource/JoyAI-Image-Edit]
Fixes #13430
Model Overview
JoyAI-Image is a unified multimodal foundation model for image understanding, text-to-image generation, and instruction-guided image editing. It combines an 8B Multimodal Large Language Model (MLLM) with a 16B Multimodal Diffusion Transformer (MMDiT).
Key Features
Image edit examples