
longcat_image model/pipeline review #13636

@hlky

Description


Commit tested: 0f1abc4ae8b0eb2a3b40e82a310507281144c423

Review performed against the repository review rules.

Duplicate search checked GitHub issues/PRs for longcat_image, LongCatImagePipeline, LongCatImageEditPipeline, LongCatImageTransformer2DModel, prompt_embeds, negative_prompt, randn_tensor, AutoPipeline, and LoRA failures. Existing overlaps are noted below.

Issue 1: prompt_embeds-only calls crash in both pipelines

Affected code:

prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
# If prompt_embeds is provided and prompt is None, skip encoding
if prompt_embeds is None:
    prompt_embeds = self._encode_prompt(prompt)

prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
# If prompt_embeds is provided and prompt is None, skip encoding
if prompt_embeds is None:
    prompt_embeds = self._encode_prompt(prompt, image)

Problem:
check_inputs() accepts prompt=None when prompt_embeds is supplied, but encode_prompt() immediately calls len(prompt). The text-to-image pipeline additionally runs prompt rewriting on the prompt before encoding unless users manually set enable_prompt_rewrite=False.

Impact:
Users cannot use precomputed prompt embeddings, despite the public prompt_embeds parameters.

Reproduction:

import torch
from diffusers import LongCatImagePipeline

pipe = LongCatImagePipeline.__new__(LongCatImagePipeline)
pipe.encode_prompt(prompt=None, prompt_embeds=torch.zeros(1, 2, 8))
# TypeError: object of type 'NoneType' has no len()

Relevant precedent:
QwenImagePipeline.encode_prompt derives batch size from prompt_embeds when prompt is absent.

Suggested fix:

if prompt_embeds is not None and prompt is None:
    batch_size = prompt_embeds.shape[0]
else:
    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt)
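
As a quick standalone sanity check of that branching (the helper name is illustrative, not part of the pipeline):

import torch

def resolve_batch_size(prompt, prompt_embeds):
    # Mirrors the suggested fix: take the batch size from prompt_embeds
    # when prompt is absent, otherwise from the (listified) prompt.
    if prompt_embeds is not None and prompt is None:
        return prompt_embeds.shape[0]
    prompt = [prompt] if isinstance(prompt, str) else prompt
    return len(prompt)

assert resolve_batch_size(None, torch.zeros(2, 77, 8)) == 2  # embeds-only path
assert resolve_batch_size("a cat", None) == 1                # plain string prompt
assert resolve_batch_size(["a", "b"], None) == 2             # batched prompts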

Issue 2: Batched prompts break CFG when negative_prompt is omitted

Affected code:

negative_prompt = "" if negative_prompt is None else negative_prompt
(prompt_embeds, text_ids) = self.encode_prompt(
    prompt=prompt, prompt_embeds=prompt_embeds, num_images_per_prompt=num_images_per_prompt
)
if self.do_classifier_free_guidance:
    (negative_prompt_embeds, negative_text_ids) = self.encode_prompt(
        prompt=negative_prompt,
        prompt_embeds=negative_prompt_embeds,
        num_images_per_prompt=num_images_per_prompt,
    )

Problem:
For prompt=["a", "b"], the pipeline sets negative_prompt = "", so unconditional embeddings have batch size 1 while latents and conditional embeddings have batch size 2.

Impact:
Batched text-to-image generation with default CFG can fail or produce invalid conditioning.

Reproduction:

prompt = ["a", "b"]
negative_prompt = ""
print(len(prompt), 1 if isinstance(negative_prompt, str) else len(negative_prompt))
# 2 1

Relevant precedent:
Flux/Qwen-style pipelines normalize a scalar negative prompt to match prompt batch size.

Suggested fix:

if negative_prompt is None:
    negative_prompt = [""] * batch_size if batch_size > 1 else ""
elif isinstance(negative_prompt, str) and batch_size > 1:
    negative_prompt = [negative_prompt] * batch_size
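
A standalone check of the broadcast behavior (function name is illustrative):

def normalize_negative_prompt(negative_prompt, batch_size):
    # Mirrors the suggested fix: broadcast a scalar negative prompt so the
    # unconditional batch matches the conditional batch under CFG.
    if negative_prompt is None:
        return [""] * batch_size if batch_size > 1 else ""
    if isinstance(negative_prompt, str) and batch_size > 1:
        return [negative_prompt] * batch_size
    return negative_prompt

assert normalize_negative_prompt(None, 2) == ["", ""]
assert normalize_negative_prompt("blurry", 2) == ["blurry", "blurry"]
assert normalize_negative_prompt(None, 1) == ""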

Issue 3: Text-to-image noise is generated in default float32, then cast

Affected code:

latents = randn_tensor(shape, generator=generator, device=device)
latents = latents.to(dtype=dtype)

Problem:
randn_tensor() is called without dtype=dtype, then cast afterward. The edit pipeline already passes dtype.

Impact:
bf16/fp16 inference gets different initial noise than a direct low-precision draw, which breaks seed parity with the edit pipeline and wastes memory on the float32 intermediate.

Reproduction:

import torch
import diffusers.pipelines.longcat_image.pipeline_longcat_image as mod
from diffusers import LongCatImagePipeline

seen = {}
def spy(shape, **kwargs):
    seen["dtype"] = kwargs.get("dtype")
    return torch.zeros(shape)

mod.randn_tensor = spy
pipe = LongCatImagePipeline.__new__(LongCatImagePipeline)
pipe.vae_scale_factor = 8
pipe.tokenizer_max_length = 512
pipe.prepare_latents(1, 16, 32, 32, torch.bfloat16, "cpu", None)
print(seen["dtype"])
# None

Relevant precedent:
FluxPipeline.prepare_latents passes dtype=dtype.

Suggested fix:

latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
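
randn_tensor already accepts dtype, so the fix is a direct call; a minimal check:

import torch
from diffusers.utils.torch_utils import randn_tensor

latents = randn_tensor(
    (1, 16, 32, 32),
    generator=torch.Generator().manual_seed(0),
    device=torch.device("cpu"),
    dtype=torch.bfloat16,
)
print(latents.dtype)
# torch.bfloat16, with no float32 intermediate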

Issue 4: Edit prompt truncation warning crashes

Affected code:

if len(all_tokens) > self.tokenizer_max_length:
    logger.warning(
        "Your input was truncated because `max_sequence_length` is set to "
        f" {self.tokenizer_max_length} input token nums : {len(len(all_tokens))}"
    )
    all_tokens = all_tokens[: self.tokenizer_max_length]

Problem:
The warning formats len(len(all_tokens)), which raises a TypeError whenever the truncation branch is reached.

Impact:
Long edit prompts crash instead of being truncated with a warning.

Reproduction:

all_tokens = [1] * 513
print(len(len(all_tokens)))
# TypeError: object of type 'int' has no len()

Relevant precedent:
The sibling text-to-image pipeline formats len(all_tokens) correctly. A duplicate report already exists: #13526

Suggested fix:

f" {self.tokenizer_max_length} input token nums : {len(all_tokens)}"

Issue 5: Edit image position IDs are forced to float64

Affected code:

image_latents_ids = prepare_pos_ids(
    modality_id=2,
    type="image",
    start=(prompt_embeds_length, prompt_embeds_length),
    height=height // 2,
    width=width // 2,
).to(device, dtype=torch.float64)

Problem:
image_latents_ids is created with dtype=torch.float64. The text pipeline does not do this, and the model immediately casts position IDs to float32 internally.

Impact:
This can fail or fall back on backends with poor/no float64 support, including MPS/NPU, and creates unnecessary dtype divergence.

Reproduction:

import torch
from diffusers.pipelines.longcat_image.pipeline_longcat_image_edit import prepare_pos_ids

ids = prepare_pos_ids(modality_id=2, type="image", height=2, width=2).to("cpu", dtype=torch.float64)
print(ids.dtype)
# torch.float64

Relevant precedent:
LongCatImagePipeline.prepare_latents keeps position IDs at the default dtype.

Suggested fix:

).to(device)

Issue 6: LoRA/joint attention integration is incomplete

Affected code:

class LongCatImagePipeline(DiffusionPipeline, FromSingleFileMixin):
    r"""
    The pipeline for text-to-image generation.
    """

    model_cpu_offload_seq = "text_encoder->transformer->vae"
    _optional_components = []
    _callback_tensor_inputs = ["latents", "prompt_embeds"]

def forward(
    self,
    hidden_states: torch.Tensor,
    encoder_hidden_states: torch.Tensor = None,
    timestep: torch.LongTensor = None,
    img_ids: torch.Tensor = None,
    txt_ids: torch.Tensor = None,
    guidance: torch.Tensor = None,
    return_dict: bool = True,
) -> torch.FloatTensor | Transformer2DModelOutput:

Problem:
The docs show a LoRA badge, but both pipelines lack a LoRA loader mixin, and the transformer forward does not accept or forward joint_attention_kwargs even though blocks/processors support it.

Impact:
pipe.load_lora_weights() is unavailable and runtime attention kwargs cannot reach processors.

Reproduction:

from diffusers import LongCatImagePipeline, LongCatImageTransformer2DModel
import inspect

print(hasattr(LongCatImagePipeline, "load_lora_weights"))
print("joint_attention_kwargs" in inspect.signature(LongCatImageTransformer2DModel.forward).parameters)
# False
# False

Relevant precedent:
QwenImagePipeline inherits QwenImageLoraLoaderMixin; FluxTransformer2DModel.forward accepts joint_attention_kwargs. Duplicate LoRA issue/PR: #12859 and #12867

Suggested fix:
Add a LongCat-specific LoRA loader mixin, inherit it in both pipelines, add joint_attention_kwargs to transformer forward, and pass it through each block.
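
A minimal runnable sketch of the kwargs pass-through pattern, using toy stand-ins rather than the real classes (the actual change would also need a LongCat-specific LoRA loader mixin modeled on QwenImageLoraLoaderMixin):

from typing import Any, Dict, Optional

import torch
import torch.nn as nn

class TinyBlock(nn.Module):
    # Stand-in for LongCatImageTransformerBlock; the real blocks already
    # accept processor kwargs, so forward() only needs to pass them on.
    def forward(
        self,
        hidden_states: torch.Tensor,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    ) -> torch.Tensor:
        scale = (joint_attention_kwargs or {}).get("scale", 1.0)
        return hidden_states * scale

class TinyTransformer(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.blocks = nn.ModuleList(TinyBlock() for _ in range(2))

    def forward(
        self,
        hidden_states: torch.Tensor,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    ) -> torch.Tensor:
        # The missing piece in LongCatImageTransformer2DModel.forward:
        # accept joint_attention_kwargs and thread it through every block.
        for block in self.blocks:
            hidden_states = block(hidden_states, joint_attention_kwargs=joint_attention_kwargs)
        return hidden_states

out = TinyTransformer()(torch.ones(1, 4), joint_attention_kwargs={"scale": 0.5})
print(out)  # each block scales by 0.5, so values are 0.25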

Issue 7: Transformer is missing _no_split_modules

Affected code:

class LongCatImageTransformer2DModel(
    ModelMixin,
    ConfigMixin,
    PeftAdapterMixin,
    FromOriginalModelMixin,
    CacheMixin,
    AttentionMixin,
):
    """
    The Transformer model introduced in Longcat-Image.
    """

    _supports_gradient_checkpointing = True
    _repeated_blocks = ["LongCatImageTransformerBlock", "LongCatImageSingleTransformerBlock"]

Problem:
The model declares repeated block classes but not _no_split_modules.

Impact:
device_map/offload placement can split individual transformer blocks across devices, which is both slower and riskier for correctness and memory.

Reproduction:

from diffusers import LongCatImageTransformer2DModel
print(getattr(LongCatImageTransformer2DModel, "_no_split_modules", None))
# None

Relevant precedent:
FluxTransformer2DModel and QwenImageTransformer2DModel set _no_split_modules.

Suggested fix:

_no_split_modules = ["LongCatImageTransformerBlock", "LongCatImageSingleTransformerBlock"]

Issue 8: Pipeline latent channels are hardcoded to 16

Affected code:

# 4. Prepare latent variables
num_channels_latents = 16
latents, latent_image_ids = self.prepare_latents(
    batch_size * num_images_per_prompt,
    num_channels_latents,
    height,
    width,
    prompt_embeds.dtype,
    device,
    generator,
    latents,
)

# 4. Prepare latent variables
num_channels_latents = 16
latents, image_latents, latents_ids, image_latents_ids = self.prepare_latents(
    image,
    batch_size * num_images_per_prompt,
    num_channels_latents,
    calculated_height,
    calculated_width,
    prompt_embeds.dtype,
    prompt_embeds.shape[1],
    device,
    generator,
    latents,
)

Problem:
Both pipelines use num_channels_latents = 16 instead of deriving it from self.transformer.config.in_channels // 4.

Impact:
Tiny/custom configs and future compatible checkpoints with different packed channel sizes fail with shape mismatches.

Reproduction:

from diffusers import LongCatImageTransformer2DModel

model = LongCatImageTransformer2DModel(in_channels=8, num_layers=1, num_single_layers=1,
    attention_head_dim=6, num_attention_heads=1, joint_attention_dim=8, axes_dims_rope=[2, 2, 2])
print(model.config.in_channels // 4)
# 2, but the pipelines always use 16

Relevant precedent:
Flux and QwenImage derive latent channels from self.transformer.config.in_channels // 4.

Suggested fix:

num_channels_latents = self.transformer.config.in_channels // 4

Issue 9: AutoPipeline mappings omit LongCat image pipelines

Affected code:

AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
    [
        ("stable-diffusion", StableDiffusionPipeline),
        ("stable-diffusion-xl", StableDiffusionXLPipeline),
        ("stable-diffusion-3", StableDiffusion3Pipeline),
        ("stable-diffusion-3-pag", StableDiffusion3PAGPipeline),
        ("if", IFPipeline),
        ("hunyuan", HunyuanDiTPipeline),
        ("hunyuan-pag", HunyuanDiTPAGPipeline),
        ("kandinsky", KandinskyCombinedPipeline),
        ("kandinsky22", KandinskyV22CombinedPipeline),
        ("kandinsky3", Kandinsky3Pipeline),
        ("stable-diffusion-controlnet", StableDiffusionControlNetPipeline),
        ("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetPipeline),
        ("stable-diffusion-xl-controlnet-union", StableDiffusionXLControlNetUnionPipeline),
        ("stable-diffusion-3-controlnet", StableDiffusion3ControlNetPipeline),
        ("wuerstchen", WuerstchenCombinedPipeline),
        ("cascade", StableCascadeCombinedPipeline),
        ("lcm", LatentConsistencyModelPipeline),
        ("pixart-alpha", PixArtAlphaPipeline),
        ("pixart-sigma", PixArtSigmaPipeline),
        ("sana", SanaPipeline),
        ("sana-pag", SanaPAGPipeline),
        ("stable-diffusion-pag", StableDiffusionPAGPipeline),
        ("stable-diffusion-controlnet-pag", StableDiffusionControlNetPAGPipeline),
        ("stable-diffusion-xl-pag", StableDiffusionXLPAGPipeline),
        ("stable-diffusion-xl-controlnet-pag", StableDiffusionXLControlNetPAGPipeline),
        ("pixart-sigma-pag", PixArtSigmaPAGPipeline),
        ("auraflow", AuraFlowPipeline),
        ("flux", FluxPipeline),
        ("flux-control", FluxControlPipeline),
        ("flux-controlnet", FluxControlNetPipeline),
        ("flux-kontext", FluxKontextPipeline),
        ("flux2-klein", Flux2KleinPipeline),
        ("flux2", Flux2Pipeline),
        ("lumina", LuminaPipeline),
        ("lumina2", Lumina2Pipeline),
        ("chroma", ChromaPipeline),
        ("cogview3", CogView3PlusPipeline),
        ("cogview4", CogView4Pipeline),
        ("glm_image", GlmImagePipeline),
        ("helios", HeliosPipeline),
        ("helios-pyramid", HeliosPyramidPipeline),
        ("cogview4-control", CogView4ControlPipeline),
        ("nucleusmoe-image", NucleusMoEImagePipeline),
        ("qwenimage", QwenImagePipeline),
        ("qwenimage-controlnet", QwenImageControlNetPipeline),
        ("z-image", ZImagePipeline),
        ("z-image-controlnet", ZImageControlNetPipeline),
        ("z-image-controlnet-inpaint", ZImageControlNetInpaintPipeline),
        ("z-image-omni", ZImageOmniPipeline),
        ("ovis", OvisImagePipeline),
        ("prx", PRXPipeline),

AUTO_IMAGE2IMAGE_PIPELINES_MAPPING = OrderedDict(
    [
        ("stable-diffusion", StableDiffusionImg2ImgPipeline),
        ("stable-diffusion-xl", StableDiffusionXLImg2ImgPipeline),
        ("stable-diffusion-3", StableDiffusion3Img2ImgPipeline),
        ("stable-diffusion-3-pag", StableDiffusion3PAGImg2ImgPipeline),
        ("if", IFImg2ImgPipeline),
        ("kandinsky", KandinskyImg2ImgCombinedPipeline),
        ("kandinsky22", KandinskyV22Img2ImgCombinedPipeline),
        ("kandinsky3", Kandinsky3Img2ImgPipeline),
        ("stable-diffusion-controlnet", StableDiffusionControlNetImg2ImgPipeline),
        ("stable-diffusion-pag", StableDiffusionPAGImg2ImgPipeline),
        ("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetImg2ImgPipeline),
        ("stable-diffusion-xl-controlnet-union", StableDiffusionXLControlNetUnionImg2ImgPipeline),
        ("stable-diffusion-xl-pag", StableDiffusionXLPAGImg2ImgPipeline),
        ("stable-diffusion-xl-controlnet-pag", StableDiffusionXLControlNetPAGImg2ImgPipeline),
        ("lcm", LatentConsistencyModelImg2ImgPipeline),
        ("flux", FluxImg2ImgPipeline),
        ("flux-controlnet", FluxControlNetImg2ImgPipeline),
        ("flux-control", FluxControlImg2ImgPipeline),
        ("flux-kontext", FluxKontextPipeline),
        ("flux2-klein", Flux2KleinPipeline),
        ("flux2", Flux2Pipeline),
        ("qwenimage", QwenImageImg2ImgPipeline),
        ("qwenimage-edit", QwenImageEditPipeline),
        ("qwenimage-edit-plus", QwenImageEditPlusPipeline),
        ("qwenimage-layered", QwenImageLayeredPipeline),
        ("z-image", ZImageImg2ImgPipeline),
Problem:
LongCatImagePipeline and LongCatImageEditPipeline are public top-level imports but are absent from AUTO_TEXT2IMAGE_PIPELINES_MAPPING and AUTO_IMAGE2IMAGE_PIPELINES_MAPPING.

Impact:
AutoPipelineForText2Image/AutoPipelineForImage2Image cannot resolve LongCat checkpoints by class name.

Reproduction:

from diffusers import LongCatImagePipeline
from diffusers.pipelines.auto_pipeline import AUTO_TEXT2IMAGE_PIPELINES_MAPPING

print(LongCatImagePipeline in AUTO_TEXT2IMAGE_PIPELINES_MAPPING.values())
# False

Relevant precedent:
Flux and QwenImage families are registered in the corresponding AutoPipeline maps.

Suggested fix:
Import the LongCat pipelines in auto_pipeline.py and register them in the corresponding mappings:

("longcat-image", LongCatImagePipeline),          # AUTO_TEXT2IMAGE_PIPELINES_MAPPING
("longcat-image-edit", LongCatImageEditPipeline), # AUTO_IMAGE2IMAGE_PIPELINES_MAPPING

Issue 10: No LongCat image fast or slow tests exist

Affected code:
tests/ has no longcat_image model or pipeline test file at this commit.

Problem:
There are no fast tests for top-level imports, save/load, tiny transformer forward, prompt-embed paths, CFG batching, dtype handling, or pipeline serialization. There are also no slow tests for the real LongCat-Image or LongCat-Image-Edit checkpoints.

Impact:
The current regressions are not covered, and the required slow coverage is missing.

Reproduction:

from pathlib import Path
print(list(Path("tests").rglob("*longcat_image*")))
# []

Relevant precedent:
tests/pipelines/longcat_audio_dit/test_longcat_audio_dit.py includes fast and slow LongCat audio coverage.

Suggested fix:
Add tests/models/transformers/test_models_transformer_longcat_image.py and tests/pipelines/longcat_image/test_longcat_image.py, including slow tests guarded by env/model availability for both text-to-image and edit checkpoints.
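
A minimal fast-test sketch for the model side (tiny-config values copied from the Issue 8 reproduction; the test name and layout are illustrative):

import tempfile

from diffusers import LongCatImageTransformer2DModel

def test_longcat_image_transformer_save_load():
    # Tiny config from the Issue 8 reproduction above.
    model = LongCatImageTransformer2DModel(
        in_channels=8,
        num_layers=1,
        num_single_layers=1,
        attention_head_dim=6,
        num_attention_heads=1,
        joint_attention_dim=8,
        axes_dims_rope=[2, 2, 2],
    )
    with tempfile.TemporaryDirectory() as tmpdir:
        model.save_pretrained(tmpdir)
        reloaded = LongCatImageTransformer2DModel.from_pretrained(tmpdir)
    assert reloaded.config.in_channels == 8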
