
kandinsky5 model/pipeline review #13639

@hlky

Description


Commit tested: 0f1abc4ae8b0eb2a3b40e82a310507281144c423

Review performed against the repository review rules.

Duplicate search performed against huggingface/diffusers Issues and PRs for kandinsky5, affected class names, output fields, prompt embedding batching, I2I tensor inputs, docs examples, return_dict, dtype/offload behavior, _no_split_modules, and slow-test coverage. Existing related items are noted inline. No GitHub issue was created before this CREATE ISSUE request.

Issue 1: Precomputed prompt embeddings break batching and CFG

Affected code:

From `Kandinsky5T2VPipeline.__call__`:

```python
if prompt is not None and isinstance(prompt, str):
    batch_size = 1
    prompt = [prompt]
elif prompt is not None and isinstance(prompt, list):
    batch_size = len(prompt)
else:
    batch_size = prompt_embeds_qwen.shape[0]

# 3. Encode input prompt
if prompt_embeds_qwen is None:
    prompt_embeds_qwen, prompt_embeds_clip, prompt_cu_seqlens = self.encode_prompt(
        prompt=prompt,
        max_sequence_length=max_sequence_length,
        device=device,
        dtype=dtype,
    )
if self.guidance_scale > 1.0:
    if negative_prompt is None:
        negative_prompt = "Static, 2D cartoon, cartoon, 2d animation, paintings, images, worst quality, low quality, ugly, deformed, walking backwards"
    if isinstance(negative_prompt, str):
        negative_prompt = [negative_prompt] * len(prompt) if prompt is not None else [negative_prompt]
    elif len(negative_prompt) != len(prompt):
        raise ValueError(
            f"`negative_prompt` must have same length as `prompt`. Got {len(negative_prompt)} vs {len(prompt)}."
        )
    if negative_prompt_embeds_qwen is None:
        negative_prompt_embeds_qwen, negative_prompt_embeds_clip, negative_prompt_cu_seqlens = (
            self.encode_prompt(
                prompt=negative_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
                dtype=dtype,
            )
        )

# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps

# 5. Prepare latent variables
num_channels_latents = self.transformer.config.in_visual_dim
latents = self.prepare_latents(
    batch_size * num_videos_per_prompt,
    num_channels_latents,
    height,
    width,
    num_frames,
    dtype,
    device,
    generator,
    latents,
)
```

From `Kandinsky5I2VPipeline.__call__`:

```python
if prompt is not None and isinstance(prompt, str):
    batch_size = 1
    prompt = [prompt]
elif prompt is not None and isinstance(prompt, list):
    batch_size = len(prompt)
else:
    batch_size = prompt_embeds_qwen.shape[0]

# 3. Encode input prompt
if prompt_embeds_qwen is None:
    prompt_embeds_qwen, prompt_embeds_clip, prompt_cu_seqlens = self.encode_prompt(
        prompt=prompt,
        num_videos_per_prompt=num_videos_per_prompt,
        max_sequence_length=max_sequence_length,
        device=device,
        dtype=dtype,
    )
if self.guidance_scale > 1.0:
    if negative_prompt is None:
        negative_prompt = "Static, 2D cartoon, cartoon, 2d animation, paintings, images, worst quality, low quality, ugly, deformed, walking backwards"
    if isinstance(negative_prompt, str):
        negative_prompt = [negative_prompt] * len(prompt) if prompt is not None else [negative_prompt]
    elif len(negative_prompt) != len(prompt):
        raise ValueError(
            f"`negative_prompt` must have same length as `prompt`. Got {len(negative_prompt)} vs {len(prompt)}."
        )
    if negative_prompt_embeds_qwen is None:
        negative_prompt_embeds_qwen, negative_prompt_embeds_clip, negative_prompt_cu_seqlens = (
            self.encode_prompt(
                prompt=negative_prompt,
                num_videos_per_prompt=num_videos_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
                dtype=dtype,
            )
        )

# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps

# 5. Prepare latent variables with image conditioning
num_channels_latents = self.transformer.config.in_visual_dim
latents = self.prepare_latents(
    image=image,
    batch_size=batch_size * num_videos_per_prompt,
    num_channels_latents=num_channels_latents,
    height=height,
    width=width,
    num_frames=num_frames,
    dtype=dtype,
    device=device,
    generator=generator,
    latents=latents,
)
```

From `Kandinsky5T2IPipeline.__call__`:

```python
if prompt is not None and isinstance(prompt, str):
    batch_size = 1
    prompt = [prompt]
elif prompt is not None and isinstance(prompt, list):
    batch_size = len(prompt)
else:
    batch_size = prompt_embeds_qwen.shape[0]

# 3. Encode input prompt
if prompt_embeds_qwen is None:
    prompt_embeds_qwen, prompt_embeds_clip, prompt_cu_seqlens = self.encode_prompt(
        prompt=prompt,
        num_images_per_prompt=num_images_per_prompt,
        max_sequence_length=max_sequence_length,
        device=device,
        dtype=dtype,
    )
if self.guidance_scale > 1.0:
    if negative_prompt is None:
        negative_prompt = ""
    if isinstance(negative_prompt, str):
        negative_prompt = [negative_prompt] * len(prompt) if prompt is not None else [negative_prompt]
    elif len(negative_prompt) != len(prompt):
        raise ValueError(
            f"`negative_prompt` must have same length as `prompt`. Got {len(negative_prompt)} vs {len(prompt)}."
        )
    if negative_prompt_embeds_qwen is None:
        negative_prompt_embeds_qwen, negative_prompt_embeds_clip, negative_prompt_cu_seqlens = (
            self.encode_prompt(
                prompt=negative_prompt,
                num_images_per_prompt=num_images_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
                dtype=dtype,
            )
        )

# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps

# 5. Prepare latent variables
num_channels_latents = self.transformer.config.in_visual_dim
latents = self.prepare_latents(
    batch_size=batch_size * num_images_per_prompt,
    num_channels_latents=num_channels_latents,
    height=height,
    width=width,
    dtype=dtype,
    device=device,
    generator=generator,
    latents=latents,
)
```

From `Kandinsky5I2IPipeline.__call__`:

```python
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
    batch_size = 1
    prompt = [prompt]
elif prompt is not None and isinstance(prompt, list):
    batch_size = len(prompt)
else:
    batch_size = prompt_embeds_qwen.shape[0]

# 3. Encode input prompt
if prompt_embeds_qwen is None:
    prompt_embeds_qwen, prompt_embeds_clip, prompt_cu_seqlens = self.encode_prompt(
        prompt=prompt,
        image=image,
        num_images_per_prompt=num_images_per_prompt,
        max_sequence_length=max_sequence_length,
        device=device,
        dtype=dtype,
    )
if self.guidance_scale > 1.0:
    if negative_prompt is None:
        negative_prompt = ""
    if isinstance(negative_prompt, str):
        negative_prompt = [negative_prompt] * len(prompt) if prompt is not None else [negative_prompt]
    elif len(negative_prompt) != len(prompt):
        raise ValueError(
            f"`negative_prompt` must have same length as `prompt`. Got {len(negative_prompt)} vs {len(prompt)}."
        )
    if negative_prompt_embeds_qwen is None:
        negative_prompt_embeds_qwen, negative_prompt_embeds_clip, negative_prompt_cu_seqlens = (
            self.encode_prompt(
                prompt=negative_prompt,
                image=image,
                num_images_per_prompt=num_images_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
                dtype=dtype,
            )
        )

# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps

# 5. Prepare latent variables with image conditioning
num_channels_latents = self.transformer.config.in_visual_dim
latents = self.prepare_latents(
    image=image,
    batch_size=batch_size * num_images_per_prompt,
    num_channels_latents=num_channels_latents,
    height=height,
    width=width,
    dtype=dtype,
    device=device,
    generator=generator,
    latents=latents,
)
```

Problem:
When prompt_embeds_qwen is supplied, the pipelines skip encode_prompt entirely, so the embeddings and prompt_cu_seqlens are never expanded for num_images_per_prompt / num_videos_per_prompt even though the latents are. CFG handling likewise sizes the default/string negative prompt from len(prompt), which breaks when prompt=None and only embeddings are passed: a string negative collapses to a single element regardless of batch size, and a list negative hits len(None) and raises a TypeError.

Impact:
Valid precomputed-embedding calls either crash with batch mismatches or fail before denoising. This affects all four Kandinsky5 pipelines.

Reproduction:

```python
import torch
from types import SimpleNamespace
from diffusers import FlowMatchEulerDiscreteScheduler, Kandinsky5T2IPipeline

class ReproPipeline(Kandinsky5T2IPipeline):
    @property
    def _execution_device(self):
        return torch.device("cpu")

class DummyModule(torch.nn.Module):
    @property
    def dtype(self):
        return torch.float32

class DummyTransformer(DummyModule):
    config = SimpleNamespace(in_visual_dim=4)
    visual_cond = False
    def forward(self, hidden_states, encoder_hidden_states, **kwargs):
        assert hidden_states.shape[0] == encoder_hidden_states.shape[0], (
            hidden_states.shape,
            encoder_hidden_states.shape,
        )
        return SimpleNamespace(sample=torch.zeros_like(hidden_states[..., :4]))

class DummyVAE(DummyModule):
    config = SimpleNamespace(scaling_factor=1.0)

pipe = ReproPipeline(DummyTransformer(), DummyVAE(), DummyModule(), None, DummyModule(), None, FlowMatchEulerDiscreteScheduler())
pipe.resolutions = [(64, 64)]
pipe.set_progress_bar_config(disable=True)

emb = torch.zeros(2, 4, 8)
pooled = torch.zeros(2, 6)
cu = torch.tensor([0, 4, 8], dtype=torch.int32)

pipe(
    prompt=None,
    prompt_embeds_qwen=emb,
    prompt_embeds_clip=pooled,
    prompt_cu_seqlens=cu,
    negative_prompt_embeds_qwen=emb,
    negative_prompt_embeds_clip=pooled,
    negative_prompt_cu_seqlens=cu,
    height=64,
    width=64,
    guidance_scale=4.0,
    num_images_per_prompt=2,
    num_inference_steps=1,
    output_type="latent",
)
```

Relevant precedent:

```python
prompt_embeds, prompt_embeds_mask = self.encode_prompt(
    prompt=prompt,
    prompt_embeds=prompt_embeds,
    prompt_embeds_mask=prompt_embeds_mask,
    device=device,
    num_images_per_prompt=num_images_per_prompt,
    max_sequence_length=max_sequence_length,
)
if do_true_cfg:
    negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt(
        prompt=negative_prompt,
        prompt_embeds=negative_prompt_embeds,
        prompt_embeds_mask=negative_prompt_embeds_mask,
        device=device,
        num_images_per_prompt=num_images_per_prompt,
        max_sequence_length=max_sequence_length,
    )
```

Suggested fix:

```python
# Route provided embeds through a helper that mirrors encode_prompt's repeat logic.
if prompt_embeds_qwen is not None:
    prompt_embeds_qwen, prompt_embeds_clip, prompt_cu_seqlens = self._repeat_prompt_embeds(
        prompt_embeds_qwen, prompt_embeds_clip, prompt_cu_seqlens, num_images_per_prompt, device
    )

negative_batch_size = batch_size
if isinstance(negative_prompt, str):
    negative_prompt = [negative_prompt] * negative_batch_size
elif negative_prompt is not None and len(negative_prompt) != negative_batch_size:
    raise ValueError(...)
```
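
For reference, a minimal sketch of the hypothetical `_repeat_prompt_embeds` helper referenced above (the name, the batched embedding layout, and the int32 cu_seqlens convention are assumptions taken from the reproduction, not from the pipeline source):

```python
def _repeat_prompt_embeds(self, embeds_qwen, embeds_clip, cu_seqlens, num_per_prompt, device):
    # Repeat each prompt's embeddings consecutively, as encode_prompt would.
    embeds_qwen = embeds_qwen.repeat_interleave(num_per_prompt, dim=0).to(device)
    embeds_clip = embeds_clip.repeat_interleave(num_per_prompt, dim=0).to(device)
    # Rebuild cumulative sequence lengths from the per-prompt lengths.
    seq_lens = (cu_seqlens[1:] - cu_seqlens[:-1]).repeat_interleave(num_per_prompt)
    cu_seqlens = torch.nn.functional.pad(seq_lens.cumsum(0), (1, 0)).to(torch.int32).to(device)
    return embeds_qwen, embeds_clip, cu_seqlens
```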

Issue 2: Image pipelines return .image instead of the standard .images

Affected code:

```python
@dataclass
class KandinskyImagePipelineOutput(BaseOutput):
    r"""
    Output class for kandinsky image pipelines.

    Args:
        image (`torch.Tensor`, `np.ndarray`, or list[PIL.Image.Image]):
            List of image outputs - It can be a nested list of length `batch_size,` with each sub-list containing
            denoised PIL image. It can also be a NumPy array or Torch tensor of shape `(batch_size, channels, height,
            width)`.
    """

    image: torch.Tensor
```

```python
# Identical return path in both the T2I and I2I pipelines
if not return_dict:
    return (image,)
return KandinskyImagePipelineOutput(image=image)
```

Problem:
Kandinsky5 image pipelines return KandinskyImagePipelineOutput(image=...). Diffusers image pipelines conventionally return an images field. The source docstring examples also call .frames[0], which is neither the actual field nor the image-pipeline convention.

Impact:
User code expecting standard Diffusers output (pipe(...).images) fails, and generated docs from source examples are misleading.

Reproduction:

```python
import torch
from diffusers.pipelines.kandinsky5.pipeline_output import KandinskyImagePipelineOutput

out = KandinskyImagePipelineOutput(image=torch.zeros(1, 3, 4, 4))
print(list(out.keys()))
print(hasattr(out, "images"), hasattr(out, "frames"))
```

Relevant precedent:

```python
class ImagePipelineOutput(BaseOutput):
    """
    Output class for image pipelines.

    Args:
        images (`List[PIL.Image.Image]` or `np.ndarray`)
            List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width,
            num_channels)`.
    """

    images: list[PIL.Image.Image] | np.ndarray
```

Suggested fix:

```python
from ..pipeline_utils import ImagePipelineOutput

# T2I / I2I return path
return ImagePipelineOutput(images=image)
```

Update tests and docs to use .images[0].
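
If backward compatibility matters, the old field could stay readable through a deprecation shim (a sketch using diffusers' `deprecate` utility; the removal version and message are placeholders):

```python
from dataclasses import dataclass

import torch

from diffusers.utils import BaseOutput, deprecate


@dataclass
class KandinskyImagePipelineOutput(BaseOutput):
    images: torch.Tensor

    @property
    def image(self):
        # Hypothetical shim: warn, then forward to the standard field.
        deprecate("image", "1.0.0", "`image` is deprecated, use `images` instead.")
        return self.images
```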

Issue 3: I2I advertises PipelineImageInput but only handles PIL images in prompt encoding

Affected code:

```python
def _encode_prompt_qwen(
    self,
    prompt: list[str],
    image: PipelineImageInput | None = None,
    device: torch.device | None = None,
    max_sequence_length: int = 1024,
    dtype: torch.dtype | None = None,
):
    """
    Encode prompt using Qwen2.5-VL text encoder.

    This method processes the input prompt through the Qwen2.5-VL model to generate text embeddings suitable for
    image generation.

    Args:
        prompt list[str]: Input list of prompts
        image (PipelineImageInput): Input list of images to condition the generation on
        device (torch.device): Device to run encoding on
        max_sequence_length (int): Maximum sequence length for tokenization
        dtype (torch.dtype): Data type for embeddings

    Returns:
        tuple[torch.Tensor, torch.Tensor]: Text embeddings and cumulative sequence lengths
    """
    device = device or self._execution_device
    dtype = dtype or self.text_encoder.dtype
    if not isinstance(image, list):
        image = [image]
    image = [i.resize((i.size[0] // 2, i.size[1] // 2)) for i in image]
```

```python
# From Kandinsky5I2IPipeline.__call__
if height is None and width is None:
    width, height = image[0].size if isinstance(image, list) else image.size
```

Problem:
Kandinsky5I2IPipeline.__call__ and _encode_prompt_qwen use PIL-only .size and .resize access. The public type is PipelineImageInput, and the docstring says tensors are accepted, but tensor/NumPy inputs fail before preprocessing can normalize them.

Impact:
Valid Diffusers image input types are rejected for I2I, despite being accepted by the VAE image processor path.

Reproduction:

```python
import torch
from diffusers import Kandinsky5I2IPipeline

pipe = object.__new__(Kandinsky5I2IPipeline)
pipe._encode_prompt_qwen(
    prompt=["edit it"],
    image=torch.zeros(1, 3, 64, 64),
    device=torch.device("cpu"),
    dtype=torch.float32,
)
```

Relevant precedent:
VaeImageProcessor.preprocess is already used later for the VAE path; the Qwen image path should normalize the same public input types before resizing.
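
For instance, `VaeImageProcessor.preprocess` already normalizes PIL, NumPy, and tensor inputs to a common tensor layout (a quick illustration; shapes chosen arbitrarily):

```python
import numpy as np
import PIL.Image
import torch
from diffusers.image_processor import VaeImageProcessor

proc = VaeImageProcessor()
pil = PIL.Image.new("RGB", (64, 64))
for inp in (pil, np.zeros((64, 64, 3), dtype=np.float32), torch.zeros(1, 3, 64, 64)):
    # All three public input types come out as a normalized NCHW tensor.
    assert proc.preprocess(inp).shape == (1, 3, 64, 64)
```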

Suggested fix:

```python
if not isinstance(image, list):
    image = [image]

# Normalize non-PIL inputs before calling the Qwen processor. numpy_to_pil alone
# would fail on tensors, so route those through pt_to_numpy first (expects NCHW).
image = [
    i if isinstance(i, PIL.Image.Image)
    else self.image_processor.numpy_to_pil(self.image_processor.pt_to_numpy(i) if torch.is_tensor(i) else i)[0]
    for i in image
]
image = [i.resize((i.size[0] // 2, i.size[1] // 2)) for i in image]
```

Issue 4: Transformer return, dtype, and device-map contracts are inconsistent

Affected code:

```python
def apply_rotary(x, rope):
    x_ = x.reshape(*x.shape[:-1], -1, 1, 2).to(torch.float32)
    x_out = (rope * x_).sum(dim=-1)
    return x_out.reshape(*x.shape).to(torch.bfloat16)
```

```python
_repeated_blocks = [
    "Kandinsky5TransformerEncoderBlock",
    "Kandinsky5TransformerDecoderBlock",
]
_keep_in_fp32_modules = ["time_embeddings", "modulation", "visual_modulation", "text_modulation"]
_supports_gradient_checkpointing = True
```

```python
if not return_dict:
    return x
return Transformer2DModelOutput(sample=x)
```

Problem:
return_dict=False returns a bare tensor instead of the conventional one-element tuple. The rotary helper hard-casts its output to torch.bfloat16, silently downcasting fp32/fp16 activations. The class also declares _repeated_blocks but no _no_split_modules.

Impact:
The public model return contract differs from related transformers, dtype behavior is not clean for fp32/fp16, and device_map="auto" lacks block no-split guidance.

Reproduction:

```python
import torch
from diffusers import Kandinsky5Transformer3DModel

model = Kandinsky5Transformer3DModel(
    in_visual_dim=4, in_text_dim=8, in_text_dim2=6, time_dim=8, out_visual_dim=4,
    patch_size=(1, 1, 1), model_dim=8, ff_dim=16, num_text_blocks=1,
    num_visual_blocks=1, axes_dims=(2, 2, 4), attention_type="regular",
).eval()

out = model(
    hidden_states=torch.randn(1, 1, 2, 2, 4),
    encoder_hidden_states=torch.randn(1, 3, 8),
    timestep=torch.tensor([1.0]),
    pooled_projections=torch.randn(1, 6),
    visual_rope_pos=[torch.arange(1), torch.arange(2), torch.arange(2)],
    text_rope_pos=torch.arange(3),
    return_dict=False,
)
print(type(out), getattr(out, "shape", None))
print(getattr(Kandinsky5Transformer3DModel, "_no_split_modules", None))
```

Relevant precedent:

```python
# The same pattern is repeated across existing transformer models:
if not return_dict:
    return (output,)
return Transformer2DModelOutput(sample=output)
```

Duplicate status:
The dtype/no-split parts overlap with #13597 and the older dtype/autocast fix in #12814 / #12809. The return_dict=False tuple issue was not found as a duplicate.

Suggested fix:

```python
# Rotary helper
orig_dtype = x.dtype
x_ = x.reshape(*x.shape[:-1], -1, 1, 2).float()
x_out = (rope * x_).sum(dim=-1)
return x_out.reshape(*x.shape).to(orig_dtype)

# Class metadata
_no_split_modules = ["Kandinsky5TransformerEncoderBlock", "Kandinsky5TransformerDecoderBlock"]

# return_dict=False
if not return_dict:
    return (x,)
```
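
A quick property check for the fixed helper (a self-contained sketch; the rope tensor here is an arbitrary stand-in for the model's rotary table, shaped to broadcast against the reshaped input):

```python
import torch


def apply_rotary(x, rope):
    # Compute in fp32 for accuracy, but return in the caller's dtype.
    orig_dtype = x.dtype
    x_ = x.reshape(*x.shape[:-1], -1, 1, 2).float()
    x_out = (rope * x_).sum(dim=-1)
    return x_out.reshape(*x.shape).to(orig_dtype)


x = torch.randn(1, 2, 4, 8)
rope = torch.randn(4, 2, 2)  # broadcasts against the (..., 4, 1, 2) reshaped input
for dtype in (torch.float32, torch.float16, torch.bfloat16):
    assert apply_rotary(x.to(dtype), rope).dtype == dtype
```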

Issue 5: I2V documentation example is not runnable

Affected code:

```python
import torch
from diffusers import Kandinsky5T2VPipeline
from diffusers.utils import export_to_video

# Load the pipeline
model_id = "kandinskylab/Kandinsky-5.0-I2V-Pro-sft-5s-Diffusers"
pipe = Kandinsky5T2VPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
pipe = pipe.to("cuda")

pipeline.transformer.set_attention_backend("flex")  # <--- Set attention bakend to Flex
pipeline.enable_model_cpu_offload()  # <--- Enable cpu offloading for single GPU inference
pipeline.transformer.compile(mode="max-autotune-no-cudagraphs", dynamic=True)  # <--- Compile with max-autotune-no-cudagraphs

# Generate video
image = load_image(
    "https://huggingface.co/kandinsky-community/kandinsky-3/resolve/main/assets/title.jpg?download=true"
)
height = 896
width = 896
image = image.resize((width, height))

prompt = "An funny furry creture smiles happily and holds a sign that says 'Kandinsky'"
negative_prompt = ""

output = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    height=height,
    width=width,
    num_frames=121,  # ~5 seconds at 24fps
    num_inference_steps=50,
    guidance_scale=5.0,
).frames[0]
```

Problem:
The I2V example imports and instantiates Kandinsky5T2VPipeline for an I2V checkpoint, uses pipeline.* even though the variable is named pipe, uses load_image without importing it, and never passes image=image to the call.

Impact:
Users following the docs cannot run image-to-video inference.

Reproduction:

```python
from pathlib import Path

text = Path("docs/source/en/api/pipelines/kandinsky5_video.md").read_text()
section = text.split("### Basic Image-to-Video Generation", 1)[1].split("## Kandinsky5T2VPipeline", 1)[0]

assert "from diffusers import Kandinsky5I2VPipeline" in section
assert "from diffusers.utils import export_to_video, load_image" in section
assert "pipe.enable_model_cpu_offload()" in section
assert "image=image" in section

Relevant precedent:
The source docstring for Kandinsky5I2VPipeline uses the correct pipeline class and imports:

```python
>>> import torch
>>> from diffusers import Kandinsky5I2VPipeline
>>> from diffusers.utils import export_to_video, load_image
>>> # Available models:
>>> # kandinskylab/Kandinsky-5.0-I2V-Pro-sft-5s-Diffusers
>>> model_id = "kandinskylab/Kandinsky-5.0-I2V-Pro-sft-5s-Diffusers"
>>> pipe = Kandinsky5I2VPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
>>> pipe = pipe.to("cuda")
>>> image = load_image(
... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
... )
>>> prompt = "An astronaut floating in space with Earth in the background, cinematic shot"
>>> negative_prompt = "Static, 2D cartoon, cartoon, 2d animation, paintings, images, worst quality, low quality, ugly, deformed, walking backwards"
>>> output = pipe(
... image=image,
... prompt=prompt,
... negative_prompt=negative_prompt,
... height=512,
... width=768,
... num_frames=121,
... num_inference_steps=50,
... guidance_scale=5.0,
... ).frames[0]
```

Suggested fix:

```python
from diffusers import Kandinsky5I2VPipeline
from diffusers.utils import export_to_video, load_image

pipe = Kandinsky5I2VPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
pipe.to("cuda")
pipe.transformer.set_attention_backend("flex")
pipe.enable_model_cpu_offload()

output = pipe(
    image=image,
    prompt=prompt,
    negative_prompt=negative_prompt,
    height=height,
    width=width,
    num_frames=121,
    num_inference_steps=50,
    guidance_scale=5.0,
).frames[0]
```
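
Since the example already imports `export_to_video`, the snippet could also show saving the result (fps chosen to match the "~5 seconds at 24fps" comment in the original):

```python
export_to_video(output, "output.mp4", fps=24)
```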

Issue 6: Slow tests are missing and several fast coverage paths are skipped

Affected code:

@unittest.skip("Only SDPA or NABLA (flex)")
def test_xformers_memory_efficient_attention(self):
pass
@unittest.skip("TODO:Test does not work")
def test_encode_prompt_works_in_isolation(self):
pass
@unittest.skip("TODO: revisit")
def test_inference_batch_single_identical(self):

@unittest.skip("TODO:Test does not work")
def test_encode_prompt_works_in_isolation(self):
pass
@unittest.skip("TODO: revisit")
def test_callback_inputs(self):
pass
@unittest.skip("TODO: revisit")
def test_inference_batch_single_identical(self):

@unittest.skip("TODO: Test does not work")
def test_encode_prompt_works_in_isolation(self):
pass
@unittest.skip("TODO: revisit, Batch isnot yet supported in this pipeline")
def test_num_images_per_prompt(self):
pass
@unittest.skip("TODO: revisit, Batch isnot yet supported in this pipeline")
def test_inference_batch_single_identical(self):
pass
@unittest.skip("TODO: revisit, Batch isnot yet supported in this pipeline")
def test_inference_batch_consistent(self):
pass
@unittest.skip("TODO: revisit, not working")
def test_float16_inference(self):

@unittest.skip("Test not supported")
def test_attention_slicing_forward_pass(self):
pass
@unittest.skip("Only SDPA or NABLA (flex)")
def test_xformers_memory_efficient_attention(self):
pass
@unittest.skip("All encoders are needed")
def test_encode_prompt_works_in_isolation(self):
pass
@unittest.skip("Meant for eiter FP32 or BF16 inference")
def test_float16_inference(self):
pass

Problem:
There are fast tests for all four pipelines, but no @slow or integration tests for Kandinsky5. Multiple fast tests for encode-prompt isolation, callback inputs, batch behavior, num_images_per_prompt, and float16 are skipped.

Impact:
The public failures above are not covered. In particular, embedding-only calls, tensor I2I inputs, callback behavior, and official checkpoint smoke tests are missing.

Reproduction:

```python
from pathlib import Path

for path in sorted(Path("tests/pipelines/kandinsky5").glob("test_*.py")):
    text = path.read_text(encoding="utf-8")
    print(path.as_posix(), "slow=", "@slow" in text, "nightly=", "@nightly" in text, "skip=", "@unittest.skip" in text)

print([
    p.as_posix()
    for p in Path("tests/models").rglob("test_*.py")
    if "Kandinsky5Transformer3DModel" in p.read_text(errors="ignore")
])
```

Relevant precedent:
Existing Kandinsky families include slow/integration coverage:

```python
class Kandinsky3PipelineIntegrationTests(unittest.TestCase):
    def setUp(self):
        # clean up the VRAM before each test
        super().setUp()
        gc.collect()
        backend_empty_cache(torch_device)

    def tearDown(self):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()
        backend_empty_cache(torch_device)

    def test_kandinskyV3(self):
        pipe = AutoPipelineForText2Image.from_pretrained(
            "kandinsky-community/kandinsky-3", variant="fp16", torch_dtype=torch.float16
        )
        pipe.enable_model_cpu_offload(device=torch_device)
        pipe.set_progress_bar_config(disable=None)
        prompt = "A photograph of the inside of a subway train. There are raccoons sitting on the seats. One of them is reading a newspaper. The window shows the city in the background."
        generator = torch.Generator(device="cpu").manual_seed(0)
        image = pipe(prompt, num_inference_steps=5, generator=generator).images[0]
```

```python
class KandinskyV22PipelineIntegrationTests(unittest.TestCase):
    def setUp(self):
        # clean up the VRAM before each test
        super().setUp()
        gc.collect()
        backend_empty_cache(torch_device)

    def tearDown(self):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()
        backend_empty_cache(torch_device)

    def test_kandinsky_text2img(self):
        expected_image = load_numpy(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
            "/kandinskyv22/kandinskyv22_text2img_cat_fp16.npy"
        )
        pipe_prior = KandinskyV22PriorPipeline.from_pretrained(
            "kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16
        )
        pipe_prior.enable_model_cpu_offload(device=torch_device)
        pipeline = KandinskyV22Pipeline.from_pretrained(
            "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16
        )
        pipeline.enable_model_cpu_offload(device=torch_device)
        pipeline.set_progress_bar_config(disable=None)
        prompt = "red cat, 4k photo"
        generator = torch.Generator(device="cpu").manual_seed(0)
        image_emb, zero_image_emb = pipe_prior(
            prompt,
            generator=generator,
```

Suggested fix:
Add slow smoke tests for the published T2V, I2V, T2I, and I2I checkpoints, and unskip or replace the fast tests for encode-prompt isolation, callbacks, batching, num_images_per_prompt, and dtype behavior. Local .venv fast-test execution currently fails at collection because the installed torch lacks torch._C._distributed_c10d, so I could not run the shared PipelineTesterMixin suite end to end in this environment.
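
A minimal slow-test sketch for the I2V checkpoint quoted in this issue, modeled on the Kandinsky3 integration tests above (the step/frame counts are placeholders, and a real test would compare against expected output slices rather than a bare None check):

```python
import gc
import unittest

import torch

from diffusers import Kandinsky5I2VPipeline
from diffusers.utils import load_image
from diffusers.utils.testing_utils import backend_empty_cache, slow, torch_device


@slow
class Kandinsky5PipelineIntegrationTests(unittest.TestCase):
    def setUp(self):
        super().setUp()
        gc.collect()
        backend_empty_cache(torch_device)

    def tearDown(self):
        super().tearDown()
        gc.collect()
        backend_empty_cache(torch_device)

    def test_kandinsky5_i2v_smoke(self):
        pipe = Kandinsky5I2VPipeline.from_pretrained(
            "kandinskylab/Kandinsky-5.0-I2V-Pro-sft-5s-Diffusers", torch_dtype=torch.bfloat16
        )
        pipe.enable_model_cpu_offload(device=torch_device)
        pipe.set_progress_bar_config(disable=None)
        image = load_image(
            "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
        )
        frames = pipe(
            image=image,
            prompt="An astronaut floating in space",
            num_frames=25,
            num_inference_steps=2,
            guidance_scale=5.0,
            generator=torch.Generator(device="cpu").manual_seed(0),
        ).frames[0]
        assert frames is not None
```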
