longcat_image model/pipeline review
Commit tested: 0f1abc4ae8b0eb2a3b40e82a310507281144c423
Review performed against the repository review rules.
Duplicate search checked GitHub issues/PRs for longcat_image, LongCatImagePipeline, LongCatImageEditPipeline, LongCatImageTransformer2DModel, prompt_embeds, negative_prompt, randn_tensor, AutoPipeline, and LoRA failures. Existing overlaps are noted below.
Issue 1: prompt_embeds-only calls crash in both pipelines
Affected code:
diffusers/src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py, lines 341 to 345 in 0f1abc4:

    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt)
    # If prompt_embeds is provided and prompt is None, skip encoding
    if prompt_embeds is None:
        prompt_embeds = self._encode_prompt(prompt)

diffusers/src/diffusers/pipelines/longcat_image/pipeline_longcat_image_edit.py, lines 351 to 355 in 0f1abc4:

    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt)
    # If prompt_embeds is provided and prompt is None, skip encoding
    if prompt_embeds is None:
        prompt_embeds = self._encode_prompt(prompt, image)
Problem:
check_inputs() allows prompt=None with prompt_embeds, but encode_prompt() immediately does len(prompt). The text-to-image pipeline also calls prompt rewriting before encoding unless users manually set enable_prompt_rewrite=False.
Impact:
Users cannot use precomputed prompt embeddings, despite the public prompt_embeds parameters.
Reproduction:
import torch
from diffusers import LongCatImagePipeline
pipe = LongCatImagePipeline.__new__(LongCatImagePipeline)
pipe.encode_prompt(prompt=None, prompt_embeds=torch.zeros(1, 2, 8))
# TypeError: object of type 'NoneType' has no len()
Relevant precedent:
QwenImagePipeline.encode_prompt derives batch size from prompt_embeds when prompt is absent.
Suggested fix:
if prompt_embeds is not None and prompt is None:
batch_size = prompt_embeds.shape[0]
else:
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
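As a standalone sketch (the helper name `resolve_batch_size` is illustrative, not a diffusers API), the corrected branch behaves as follows:

```python
def resolve_batch_size(prompt, prompt_embeds):
    # Mirrors the suggested fix: when prompt is None, derive the batch
    # size from the precomputed embeddings instead of calling len(None).
    if prompt_embeds is not None and prompt is None:
        return prompt_embeds.shape[0]
    prompt = [prompt] if isinstance(prompt, str) else prompt
    return len(prompt)
```

With this ordering, an `encode_prompt(prompt=None, prompt_embeds=...)` call never reaches `len(prompt)`.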
Issue 2: Batched prompts break CFG when negative_prompt is omitted
Affected code:
diffusers/src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py, lines 542 to 551 in 0f1abc4:

    negative_prompt = "" if negative_prompt is None else negative_prompt
    (prompt_embeds, text_ids) = self.encode_prompt(
        prompt=prompt, prompt_embeds=prompt_embeds, num_images_per_prompt=num_images_per_prompt
    )
    if self.do_classifier_free_guidance:
        (negative_prompt_embeds, negative_text_ids) = self.encode_prompt(
            prompt=negative_prompt,
            prompt_embeds=negative_prompt_embeds,
            num_images_per_prompt=num_images_per_prompt,
        )
Problem:
For prompt=["a", "b"], the pipeline sets negative_prompt = "", so unconditional embeddings have batch size 1 while latents and conditional embeddings have batch size 2.
Impact:
Batched text-to-image generation with default CFG can fail or produce invalid conditioning.
Reproduction:
prompt = ["a", "b"]
negative_prompt = ""
print(len(prompt), 1 if isinstance(negative_prompt, str) else len(negative_prompt))
# 2 1
Relevant precedent:
Flux/Qwen-style pipelines normalize a scalar negative prompt to match prompt batch size.
Suggested fix:
if negative_prompt is None:
negative_prompt = [""] * batch_size if batch_size > 1 else ""
elif isinstance(negative_prompt, str) and batch_size > 1:
negative_prompt = [negative_prompt] * batch_size
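The same normalization as a standalone sketch (the helper name `normalize_negative_prompt` is illustrative only):

```python
def normalize_negative_prompt(negative_prompt, batch_size):
    # Mirrors the suggested fix: expand a scalar negative prompt so the
    # unconditional batch matches the conditional batch under CFG.
    if negative_prompt is None:
        return [""] * batch_size if batch_size > 1 else ""
    if isinstance(negative_prompt, str) and batch_size > 1:
        return [negative_prompt] * batch_size
    return negative_prompt
```

For `prompt=["a", "b"]` this yields `["", ""]`, so both embedding batches have size 2.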
Issue 3: Text-to-image noise is generated in default float32, then cast
Affected code:
diffusers/src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py, lines 419 to 420 in 0f1abc4:

    latents = randn_tensor(shape, generator=generator, device=device)
    latents = latents.to(dtype=dtype)
Problem:
randn_tensor() is called without dtype=dtype, then cast afterward. The edit pipeline already passes dtype.
Impact:
bf16/fp16 inference gets different initial noise than a direct low-precision draw, which can hurt parity, and the transient float32 tensor wastes memory.
Reproduction:
import torch
import diffusers.pipelines.longcat_image.pipeline_longcat_image as mod
from diffusers import LongCatImagePipeline
seen = {}
def spy(shape, **kwargs):
seen["dtype"] = kwargs.get("dtype")
return torch.zeros(shape)
mod.randn_tensor = spy
pipe = LongCatImagePipeline.__new__(LongCatImagePipeline)
pipe.vae_scale_factor = 8
pipe.tokenizer_max_length = 512
pipe.prepare_latents(1, 16, 32, 32, torch.bfloat16, "cpu", None)
print(seen["dtype"])
# None
Relevant precedent:
FluxPipeline.prepare_latents passes dtype=dtype.
Suggested fix:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
Issue 4: Edit prompt truncation warning crashes
Affected code:
diffusers/src/diffusers/pipelines/longcat_image/pipeline_longcat_image_edit.py, lines 284 to 289 in 0f1abc4:

    if len(all_tokens) > self.tokenizer_max_length:
        logger.warning(
            "Your input was truncated because `max_sequence_length` is set to "
            f" {self.tokenizer_max_length} input token nums : {len(len(all_tokens))}"
        )
        all_tokens = all_tokens[: self.tokenizer_max_length]
Problem:
The warning formats len(len(all_tokens)), which raises when the truncation branch is reached.
Impact:
Long edit prompts crash instead of being truncated with a warning.
Reproduction:
all_tokens = [1] * 513
print(len(len(all_tokens)))
# TypeError: object of type 'int' has no len()
Relevant precedent:
The sibling text-to-image pipeline uses len(all_tokens) correctly. Duplicate already exists: #13526
Suggested fix:
f" {self.tokenizer_max_length} input token nums : {len(all_tokens)}"
Issue 5: Edit image position IDs are forced to float64
Affected code:
diffusers/src/diffusers/pipelines/longcat_image/pipeline_longcat_image_edit.py, lines 447 to 453 in 0f1abc4:

    image_latents_ids = prepare_pos_ids(
        modality_id=2,
        type="image",
        start=(prompt_embeds_length, prompt_embeds_length),
        height=height // 2,
        width=width // 2,
    ).to(device, dtype=torch.float64)
Problem:
image_latents_ids is created with dtype=torch.float64. The text pipeline does not do this, and the model immediately casts position IDs to float32 internally.
Impact:
This can fail or fall back on backends with poor/no float64 support, including MPS/NPU, and creates unnecessary dtype divergence.
Reproduction:
import torch
from diffusers.pipelines.longcat_image.pipeline_longcat_image_edit import prepare_pos_ids
ids = prepare_pos_ids(modality_id=2, type="image", height=2, width=2).to("cpu", dtype=torch.float64)
print(ids.dtype)
# torch.float64
Relevant precedent:
LongCatImagePipeline.prepare_latents keeps position IDs at the default dtype.
Suggested fix:
Drop the explicit float64 cast and keep the default dtype, matching the text-to-image pipeline:
    ).to(device)
Issue 6: LoRA/joint attention integration is incomplete
Affected code:
diffusers/src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py, lines 205 to 213 in 0f1abc4:

    class LongCatImagePipeline(DiffusionPipeline, FromSingleFileMixin):
        r"""
        The pipeline for text-to-image generation.
        """

        model_cpu_offload_seq = "text_encoder->transformer->vae"
        _optional_components = []
        _callback_tensor_inputs = ["latents", "prompt_embeds"]

diffusers/src/diffusers/models/transformers/transformer_longcat_image.py, lines 466 to 475 in 0f1abc4:

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor = None,
        timestep: torch.LongTensor = None,
        img_ids: torch.Tensor = None,
        txt_ids: torch.Tensor = None,
        guidance: torch.Tensor = None,
        return_dict: bool = True,
    ) -> torch.FloatTensor | Transformer2DModelOutput:
Problem:
The docs show a LoRA badge, but both pipelines lack a LoRA loader mixin, and the transformer forward does not accept or forward joint_attention_kwargs even though blocks/processors support it.
Impact:
pipe.load_lora_weights() is unavailable and runtime attention kwargs cannot reach processors.
Reproduction:
from diffusers import LongCatImagePipeline, LongCatImageTransformer2DModel
import inspect
print(hasattr(LongCatImagePipeline, "load_lora_weights"))
print("joint_attention_kwargs" in inspect.signature(LongCatImageTransformer2DModel.forward).parameters)
# False
# False
Relevant precedent:
QwenImagePipeline inherits QwenImageLoraLoaderMixin; FluxTransformer2DModel.forward accepts joint_attention_kwargs. Duplicate LoRA issue/PR: #12859 and #12867
Suggested fix:
Add a LongCat-specific LoRA loader mixin, inherit it in both pipelines, add joint_attention_kwargs to transformer forward, and pass it through each block.
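A minimal sketch of the kwargs threading, using toy classes rather than the real diffusers modules (all names here are stand-ins):

```python
class ToyBlock:
    # Stand-in for a transformer block; a real block would hand the
    # kwargs to its attention processor.
    def forward(self, hidden_states, joint_attention_kwargs=None):
        kwargs = joint_attention_kwargs or {}
        return hidden_states, kwargs


class ToyTransformer:
    # Stand-in for a transformer forward that accepts and propagates
    # joint_attention_kwargs, as FluxTransformer2DModel.forward does.
    def __init__(self, num_blocks=2):
        self.blocks = [ToyBlock() for _ in range(num_blocks)]

    def forward(self, hidden_states, joint_attention_kwargs=None):
        seen = {}
        for block in self.blocks:
            hidden_states, seen = block.forward(
                hidden_states, joint_attention_kwargs=joint_attention_kwargs
            )
        return hidden_states, seen
```

The point is only the plumbing: the top-level forward must take the parameter and pass it to every block, or runtime kwargs such as LoRA scales never reach the processors.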
Issue 7: Transformer is missing _no_split_modules
Affected code:
diffusers/src/diffusers/models/transformers/transformer_longcat_image.py, lines 397 to 410 in 0f1abc4:

    class LongCatImageTransformer2DModel(
        ModelMixin,
        ConfigMixin,
        PeftAdapterMixin,
        FromOriginalModelMixin,
        CacheMixin,
        AttentionMixin,
    ):
        """
        The Transformer model introduced in Longcat-Image.
        """

        _supports_gradient_checkpointing = True
        _repeated_blocks = ["LongCatImageTransformerBlock", "LongCatImageSingleTransformerBlock"]
Problem:
The model declares repeated block classes but not _no_split_modules.
Impact:
device_map/offload placement can split residual attention blocks across devices, which is both slower and riskier for correctness/memory.
Reproduction:
from diffusers import LongCatImageTransformer2DModel
print(getattr(LongCatImageTransformer2DModel, "_no_split_modules", None))
# None
Relevant precedent:
FluxTransformer2DModel and QwenImageTransformer2DModel set _no_split_modules.
Suggested fix:
_no_split_modules = ["LongCatImageTransformerBlock", "LongCatImageSingleTransformerBlock"]
Issue 8: Pipeline latent channels are hardcoded to 16
Affected code:
diffusers/src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py, lines 553 to 564 in 0f1abc4:

    # 4. Prepare latent variables
    num_channels_latents = 16
    latents, latent_image_ids = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

diffusers/src/diffusers/pipelines/longcat_image/pipeline_longcat_image_edit.py, lines 607 to 620 in 0f1abc4:

    # 4. Prepare latent variables
    num_channels_latents = 16
    latents, image_latents, latents_ids, image_latents_ids = self.prepare_latents(
        image,
        batch_size * num_images_per_prompt,
        num_channels_latents,
        calculated_height,
        calculated_width,
        prompt_embeds.dtype,
        prompt_embeds.shape[1],
        device,
        generator,
        latents,
    )
Problem:
Both pipelines use num_channels_latents = 16 instead of deriving it from self.transformer.config.in_channels // 4.
Impact:
Tiny/custom configs and future compatible checkpoints with different packed channel sizes fail with shape mismatches.
Reproduction:
from diffusers import LongCatImageTransformer2DModel
model = LongCatImageTransformer2DModel(in_channels=8, num_layers=1, num_single_layers=1,
attention_head_dim=6, num_attention_heads=1, joint_attention_dim=8, axes_dims_rope=[2, 2, 2])
print(model.config.in_channels // 4)
# 2, but the pipelines always use 16
Relevant precedent:
Flux and QwenImage derive latent channels from self.transformer.config.in_channels // 4.
Suggested fix:
num_channels_latents = self.transformer.config.in_channels // 4
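The packing arithmetic behind the fix, as a sketch (function name illustrative): LongCat, like Flux, packs 2x2 latent patches, so the transformer's `in_channels` is 4x the VAE latent channel count.

```python
def derive_num_channels_latents(transformer_in_channels):
    # Invert the 2x2 patch packing: each packed token carries
    # 4 * num_channels_latents values.
    return transformer_in_channels // 4

# A default LongCat-style config (in_channels=64) still yields 16, while
# the tiny config from the reproduction (in_channels=8) yields 2.
```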
Issue 9: AutoPipeline mappings omit LongCat image pipelines
Affected code:
diffusers/src/diffusers/pipelines/auto_pipeline.py, lines 139 to 191 in 0f1abc4:

    AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
        [
            ("stable-diffusion", StableDiffusionPipeline),
            ("stable-diffusion-xl", StableDiffusionXLPipeline),
            ("stable-diffusion-3", StableDiffusion3Pipeline),
            ("stable-diffusion-3-pag", StableDiffusion3PAGPipeline),
            ("if", IFPipeline),
            ("hunyuan", HunyuanDiTPipeline),
            ("hunyuan-pag", HunyuanDiTPAGPipeline),
            ("kandinsky", KandinskyCombinedPipeline),
            ("kandinsky22", KandinskyV22CombinedPipeline),
            ("kandinsky3", Kandinsky3Pipeline),
            ("stable-diffusion-controlnet", StableDiffusionControlNetPipeline),
            ("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetPipeline),
            ("stable-diffusion-xl-controlnet-union", StableDiffusionXLControlNetUnionPipeline),
            ("stable-diffusion-3-controlnet", StableDiffusion3ControlNetPipeline),
            ("wuerstchen", WuerstchenCombinedPipeline),
            ("cascade", StableCascadeCombinedPipeline),
            ("lcm", LatentConsistencyModelPipeline),
            ("pixart-alpha", PixArtAlphaPipeline),
            ("pixart-sigma", PixArtSigmaPipeline),
            ("sana", SanaPipeline),
            ("sana-pag", SanaPAGPipeline),
            ("stable-diffusion-pag", StableDiffusionPAGPipeline),
            ("stable-diffusion-controlnet-pag", StableDiffusionControlNetPAGPipeline),
            ("stable-diffusion-xl-pag", StableDiffusionXLPAGPipeline),
            ("stable-diffusion-xl-controlnet-pag", StableDiffusionXLControlNetPAGPipeline),
            ("pixart-sigma-pag", PixArtSigmaPAGPipeline),
            ("auraflow", AuraFlowPipeline),
            ("flux", FluxPipeline),
            ("flux-control", FluxControlPipeline),
            ("flux-controlnet", FluxControlNetPipeline),
            ("flux-kontext", FluxKontextPipeline),
            ("flux2-klein", Flux2KleinPipeline),
            ("flux2", Flux2Pipeline),
            ("lumina", LuminaPipeline),
            ("lumina2", Lumina2Pipeline),
            ("chroma", ChromaPipeline),
            ("cogview3", CogView3PlusPipeline),
            ("cogview4", CogView4Pipeline),
            ("glm_image", GlmImagePipeline),
            ("helios", HeliosPipeline),
            ("helios-pyramid", HeliosPyramidPipeline),
            ("cogview4-control", CogView4ControlPipeline),
            ("nucleusmoe-image", NucleusMoEImagePipeline),
            ("qwenimage", QwenImagePipeline),
            ("qwenimage-controlnet", QwenImageControlNetPipeline),
            ("z-image", ZImagePipeline),
            ("z-image-controlnet", ZImageControlNetPipeline),
            ("z-image-controlnet-inpaint", ZImageControlNetInpaintPipeline),
            ("z-image-omni", ZImageOmniPipeline),
            ("ovis", OvisImagePipeline),
            ("prx", PRXPipeline),

diffusers/src/diffusers/pipelines/auto_pipeline.py, lines 195 to 222 in 0f1abc4:

    AUTO_IMAGE2IMAGE_PIPELINES_MAPPING = OrderedDict(
        [
            ("stable-diffusion", StableDiffusionImg2ImgPipeline),
            ("stable-diffusion-xl", StableDiffusionXLImg2ImgPipeline),
            ("stable-diffusion-3", StableDiffusion3Img2ImgPipeline),
            ("stable-diffusion-3-pag", StableDiffusion3PAGImg2ImgPipeline),
            ("if", IFImg2ImgPipeline),
            ("kandinsky", KandinskyImg2ImgCombinedPipeline),
            ("kandinsky22", KandinskyV22Img2ImgCombinedPipeline),
            ("kandinsky3", Kandinsky3Img2ImgPipeline),
            ("stable-diffusion-controlnet", StableDiffusionControlNetImg2ImgPipeline),
            ("stable-diffusion-pag", StableDiffusionPAGImg2ImgPipeline),
            ("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetImg2ImgPipeline),
            ("stable-diffusion-xl-controlnet-union", StableDiffusionXLControlNetUnionImg2ImgPipeline),
            ("stable-diffusion-xl-pag", StableDiffusionXLPAGImg2ImgPipeline),
            ("stable-diffusion-xl-controlnet-pag", StableDiffusionXLControlNetPAGImg2ImgPipeline),
            ("lcm", LatentConsistencyModelImg2ImgPipeline),
            ("flux", FluxImg2ImgPipeline),
            ("flux-controlnet", FluxControlNetImg2ImgPipeline),
            ("flux-control", FluxControlImg2ImgPipeline),
            ("flux-kontext", FluxKontextPipeline),
            ("flux2-klein", Flux2KleinPipeline),
            ("flux2", Flux2Pipeline),
            ("qwenimage", QwenImageImg2ImgPipeline),
            ("qwenimage-edit", QwenImageEditPipeline),
            ("qwenimage-edit-plus", QwenImageEditPlusPipeline),
            ("qwenimage-layered", QwenImageLayeredPipeline),
            ("z-image", ZImageImg2ImgPipeline),
Problem:
LongCatImagePipeline and LongCatImageEditPipeline are public top-level imports but are absent from AUTO_TEXT2IMAGE_PIPELINES_MAPPING and AUTO_IMAGE2IMAGE_PIPELINES_MAPPING.
Impact:
AutoPipelineForText2Image/AutoPipelineForImage2Image cannot resolve LongCat checkpoints by class name.
Reproduction:
from diffusers import LongCatImagePipeline
from diffusers.pipelines.auto_pipeline import AUTO_TEXT2IMAGE_PIPELINES_MAPPING
print(LongCatImagePipeline in AUTO_TEXT2IMAGE_PIPELINES_MAPPING.values())
# False
Relevant precedent:
Flux and QwenImage families are registered in the corresponding AutoPipeline maps.
Suggested fix:
Import LongCat pipelines in auto_pipeline.py and add:
("longcat-image", LongCatImagePipeline)
("longcat-image-edit", LongCatImageEditPipeline)
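A toy illustration of the registration pattern (the class bodies are stand-ins for the real imports; only the mapping shape mirrors auto_pipeline.py):

```python
from collections import OrderedDict


class LongCatImagePipeline:  # stand-in for the real import
    pass


class LongCatImageEditPipeline:  # stand-in for the real import
    pass


AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
    [("longcat-image", LongCatImagePipeline)]
)
AUTO_IMAGE2IMAGE_PIPELINES_MAPPING = OrderedDict(
    [("longcat-image-edit", LongCatImageEditPipeline)]
)
```

Once the real classes are registered this way, the reproduction check above prints True.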
Issue 10: No LongCat image fast or slow tests exist
Affected code:
tests/ has no longcat_image model or pipeline test file at this commit.
Problem:
There are no fast tests for top-level imports, save/load, tiny transformer forward, prompt-embed paths, CFG batching, dtype handling, or pipeline serialization. There are also no slow tests for the real LongCat-Image or LongCat-Image-Edit checkpoints.
Impact:
The current regressions are not covered, and the required slow coverage is missing.
Reproduction:
from pathlib import Path
print(list(Path("tests").rglob("*longcat_image*")))
# []
Relevant precedent:
tests/pipelines/longcat_audio_dit/test_longcat_audio_dit.py includes fast and slow LongCat audio coverage.
Suggested fix:
Add tests/models/transformers/test_models_transformer_longcat_image.py and tests/pipelines/longcat_image/test_longcat_image.py, including slow tests guarded by env/model availability for both text-to-image and edit checkpoints.
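A possible skeleton for the new pipeline test file (class, method, and env-var names are suggestions, not existing diffusers conventions beyond the usual fast/slow split):

```python
import os
import unittest


class LongCatImageFastTests(unittest.TestCase):
    # Fast tests would build a tiny transformer/pipeline; bodies elided.
    def test_prompt_embeds_only_path(self):
        self.skipTest("sketch only")

    def test_cfg_batching_with_default_negative_prompt(self):
        self.skipTest("sketch only")


@unittest.skipUnless(os.environ.get("RUN_SLOW") == "1", "slow LongCat tests")
class LongCatImageSlowTests(unittest.TestCase):
    def test_text_to_image_checkpoint(self):
        self.skipTest("sketch only")
```

The fast class would cover the regressions above (Issues 1 to 5, 8); the guarded slow class would exercise the real text-to-image and edit checkpoints.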