stable_video_diffusion model/pipeline review #13627

@hlky

Description

Commit tested: 0f1abc4ae8b0eb2a3b40e82a310507281144c423

Review performed against the repository review rules.

Files/categories reviewed: target pipeline/model files, lazy imports/top-level exports/dummy objects, fast and slow tests, docs, deprecation status, dtype/device/offload/callback behavior, config validation, and related video pipeline precedents. Fast and slow tests both exist, so there is no missing-slow-test finding; however, existing coverage still skips batch consistency and fp16 inference in tests/pipelines/stable_video_diffusion/test_stable_video_diffusion.py.

Duplicate search status: searched GitHub issues and PRs for stable_video_diffusion, StableVideoDiffusionPipeline, UNetSpatioTemporalConditionModel, return_dict, tensor image/CLIP resize, guidance scale, custom latents dtype, callback tensor inputs, and tuple config validation. No exact duplicates were found; the one exception is the tensor-image finding (Issue 3), which is a remaining part of closed issue #6574 and merged PR #6999.

Issue 1: return_dict=False returns the raw frames object, not a tuple

Affected code:

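if not return_dict:
    return frames
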
Problem:
The pipeline returns the frames object directly when return_dict=False, while the diffusers pipeline convention is to return a tuple even for single-output pipelines. Current tests index [0], which hides the bug for batch size 1 because indexing the raw tensor/list returns the first video rather than the frames output field.
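
A minimal illustration of the indexing ambiguity (shapes chosen to match the reproduction below):

import torch

frames = torch.rand(1, 2, 3, 32, 32)  # (batch, num_frames, C, H, W)
raw = frames                          # current return_dict=False behavior
tup = (frames,)                       # conventional tuple behavior
print(raw[0].shape)  # torch.Size([2, 3, 32, 32]): first batch element
print(tup[0].shape)  # torch.Size([1, 2, 3, 32, 32]): the frames output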

Impact:
Callers expecting the standard tuple contract get a torch.Tensor, np.ndarray, or list directly. This breaks generic pipeline wrappers and makes pipe(..., return_dict=False)[0] mean “first batch element” instead of “frames output”.

Reproduction:

import torch
from transformers import CLIPImageProcessor, CLIPVisionConfig, CLIPVisionModelWithProjection
from diffusers import AutoencoderKLTemporalDecoder, EulerDiscreteScheduler, StableVideoDiffusionPipeline, UNetSpatioTemporalConditionModel

def make_pipe():
    unet = UNetSpatioTemporalConditionModel(
        block_out_channels=(32, 64), layers_per_block=1, sample_size=32, in_channels=8, out_channels=4,
        down_block_types=("CrossAttnDownBlockSpatioTemporal", "DownBlockSpatioTemporal"),
        up_block_types=("UpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal"),
        cross_attention_dim=32, num_attention_heads=8,
        projection_class_embeddings_input_dim=96, addition_time_embed_dim=32,
    )
    vae = AutoencoderKLTemporalDecoder(block_out_channels=[32, 64], in_channels=3, out_channels=3, down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], latent_channels=4)
    image_encoder = CLIPVisionModelWithProjection(CLIPVisionConfig(hidden_size=32, projection_dim=32, num_hidden_layers=1, num_attention_heads=4, image_size=32, intermediate_size=37, patch_size=1))
    pipe = StableVideoDiffusionPipeline(vae=vae, image_encoder=image_encoder, unet=unet, scheduler=EulerDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"), feature_extractor=CLIPImageProcessor(crop_size=32, size=32))
    pipe.set_progress_bar_config(disable=True)
    return pipe

out = make_pipe()(image=torch.rand(1, 3, 32, 32), height=32, width=32, num_frames=2, num_inference_steps=1, output_type="pt", return_dict=False)
print(type(out), isinstance(out, tuple), out.shape)
# Current: <class 'torch.Tensor'> False torch.Size([1, 2, 3, 32, 32])

Relevant precedent:

if not return_dict:
    return (video,)

return WanPipelineOutput(frames=video)

if not return_dict:
    return (video,)

return LTXPipelineOutput(frames=video)

Suggested fix:

if not return_dict:
    return (frames,)
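
For context, the surrounding return block would then match the Wan/LTX precedent above (a sketch; StableVideoDiffusionPipelineOutput is the pipeline's existing output class):

if not return_dict:
    return (frames,)

return StableVideoDiffusionPipelineOutput(frames=frames)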

Issue 2: Decreasing guidance scales can crash CFG batching

Affected code:

# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
self._guidance_scale = max_guidance_scale

# 3. Encode input image
image_embeddings = self._encode_image(image, device, num_videos_per_prompt, self.do_classifier_free_guidance)

# 8. Prepare guidance scale
guidance_scale = torch.linspace(min_guidance_scale, max_guidance_scale, num_frames).unsqueeze(0)
guidance_scale = guidance_scale.to(device, latents.dtype)
guidance_scale = guidance_scale.repeat(batch_size * num_videos_per_prompt, 1)
guidance_scale = _append_dims(guidance_scale, latents.ndim)

self._guidance_scale = guidance_scale

# 9. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
    for i, t in enumerate(timesteps):
        # expand the latents if we are doing classifier free guidance
        latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
        latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

        # Concatenate image_latents over channels dimension
        latent_model_input = torch.cat([latent_model_input, image_latents], dim=2)

Problem:
Before the image and conditioning tensors are duplicated for CFG, the pipeline sets self._guidance_scale = max_guidance_scale. If min_guidance_scale > 1 but max_guidance_scale <= 1, do_classifier_free_guidance is false at that point, so conditioning is prepared without CFG duplication. Step 8 then replaces self._guidance_scale with the full per-frame tensor, whose maximum exceeds 1, so do_classifier_free_guidance flips to true and the denoising loop duplicates only the latents. The subsequent concat with the non-duplicated image_latents fails on the batch dimension.

Impact:
A valid decreasing guidance schedule, for example stronger first-frame guidance and no final-frame guidance, crashes at runtime.

Reproduction:

# Reuse the make_pipe() definition from Issue 1.
import torch

pipe = make_pipe()
try:
    pipe(
        image=torch.rand(1, 3, 32, 32),
        height=32,
        width=32,
        num_frames=2,
        num_inference_steps=1,
        output_type="pt",
        min_guidance_scale=2.0,
        max_guidance_scale=1.0,
    )
except Exception as e:
    print(type(e).__name__, str(e).split("\n")[0])
# RuntimeError Sizes of tensors must match except in dimension 2. Expected size 2 but got size 1 ...

Relevant precedent:
Merged PR #7143 fixed a related CFG disable regression for scalar max_guidance_scale=1, but it does not cover guidance ranges where only min_guidance_scale crosses the CFG threshold.

Suggested fix:

self._guidance_scale = max(min_guidance_scale, max_guidance_scale)
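
Why the max of both endpoints is the right scalar: do_classifier_free_guidance compares the guidance scale against 1, and a linspace schedule exceeds 1 somewhere exactly when its larger endpoint does. A minimal sketch of that property (needs_cfg is a hypothetical helper, not pipeline code):

def needs_cfg(min_guidance_scale: float, max_guidance_scale: float) -> bool:
    # CFG is needed whenever any per-frame scale in the linspace schedule
    # exceeds 1, i.e. whenever the larger endpoint exceeds 1.
    return max(min_guidance_scale, max_guidance_scale) > 1

assert needs_cfg(2.0, 1.0)        # decreasing schedule: CFG must stay on
assert needs_cfg(1.0, 3.0)        # increasing schedule
assert not needs_cfg(1.0, 1.0)    # guidance fully disabled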

Issue 3: Tensor image inputs are not resized before CLIP encoding

Affected code:

if not isinstance(image, torch.Tensor):
    image = self.video_processor.pil_to_numpy(image)
    image = self.video_processor.numpy_to_pt(image)

    # We normalize the image before resizing to match with the original implementation.
    # Then we unnormalize it after resizing.
    image = image * 2.0 - 1.0
    image = _resize_with_antialiasing(image, (224, 224))
    image = (image + 1.0) / 2.0

# Normalize the image with for CLIP input
image = self.feature_extractor(
    images=image,
    do_normalize=True,
    do_center_crop=False,
    do_resize=False,
    do_rescale=False,
    return_tensors="pt",
).pixel_values

image = image.to(device=device, dtype=dtype)
image_embeddings = self.image_encoder(image).image_embeds

Problem:
The CLIP resize path is inside if not isinstance(image, torch.Tensor). PIL/list inputs are resized to CLIP resolution before image_encoder, but tensor inputs go directly into CLIPImageProcessor(..., do_resize=False) and then CLIP. A normal SVD tensor input at generation size, such as [1, 3, 576, 1024], is therefore incompatible with the CLIP image encoder.

Impact:
The docstring allows tensor images in [0, 1], but users must pre-resize tensor inputs to the image encoder size themselves, which is nowhere documented. This is inconsistent with PIL inputs and with the closed tensor-input bug history in issue #6574 / PR #6999.

Reproduction:

# Reuse the make_pipe() definition from Issue 1.
import torch

pipe = make_pipe()
try:
    pipe(
        image=torch.rand(1, 3, 64, 64),
        height=64,
        width=64,
        num_frames=2,
        num_inference_steps=1,
        output_type="pt",
    )
except Exception as e:
    print(type(e).__name__, str(e).split("\n")[0])
# ValueError Input image size (64*64) doesn't match model (32*32).

Relevant precedent:
#6574
#6999

Suggested fix:

if not isinstance(image, torch.Tensor):
    image = self.video_processor.pil_to_numpy(image)
    image = self.video_processor.numpy_to_pt(image)

image = image * 2.0 - 1.0
image = _resize_with_antialiasing(image, (224, 224))
image = (image + 1.0) / 2.0
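
Under this fix, a generation-size tensor reaches CLIP at 224x224. A rough stand-in for _resize_with_antialiasing using torch.nn.functional.interpolate, just to illustrate the shape contract (assumption: the real helper likewise maps any HxW input to the target size):

import torch
import torch.nn.functional as F

image = torch.rand(1, 3, 576, 1024)  # typical SVD generation-size tensor input
image = image * 2.0 - 1.0            # normalize to [-1, 1] before resizing
image = F.interpolate(image, size=(224, 224), mode="bicubic", antialias=True)
image = (image + 1.0) / 2.0          # back to [0, 1] for CLIP preprocessing
print(image.shape)                   # torch.Size([1, 3, 224, 224])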

Issue 4: Custom latents are moved to device but not cast to pipeline dtype

Affected code:

if latents is None:
    latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
    latents = latents.to(device)

# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma

Problem:
Generated latents use dtype=image_embeddings.dtype, but user-provided latents only call latents.to(device). In a half-precision pipeline, float32 custom latents promote the concatenated UNet input to float32, while UNet weights are float16.
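
A quick check of the promotion rule behind the crash (illustrative only):

import torch

# torch.cat type-promotes mixed inputs, so fp32 custom latents keep the
# concatenated UNet input in fp32 even when image_latents are fp16.
a = torch.randn(1, 4, dtype=torch.float32)  # stand-in for user-provided latents
b = torch.randn(1, 4, dtype=torch.float16)  # stand-in for image_latents
print(torch.cat([a, b], dim=1).dtype)       # torch.float32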

Impact:
Supplying precomputed fp32 latents to an fp16 pipeline can crash with dtype mismatches instead of being normalized to the pipeline’s working dtype.

Reproduction:

# Reuse the make_pipe() definition from Issue 1.
import torch

pipe = make_pipe().to(dtype=torch.float16)
latents = torch.randn(1, 2, 4, 16, 16, dtype=torch.float32)
try:
    pipe(image=torch.rand(1, 3, 32, 32), height=32, width=32, num_frames=2, num_inference_steps=1, output_type="pt", latents=latents)
except Exception as e:
    print(type(e).__name__, str(e).split("\n")[0])
# RuntimeError mat1 and mat2 must have the same dtype, but got Float and Half

Relevant precedent:

if latents is not None:
    return latents.to(device=device, dtype=dtype)

if latents is not None:
    return latents.to(device=device, dtype=dtype)

Suggested fix:

else:
    latents = latents.to(device=device, dtype=dtype)

Issue 5: Tuple config length validation is incomplete in the spatio-temporal UNet

Affected code:

if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
    raise ValueError(
        f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
    )

if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
    raise ValueError(
        f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
    )

# input
self.conv_in = nn.Conv2d(
    in_channels,
    block_out_channels[0],
    kernel_size=3,
    padding=1,
)

# time
time_embed_dim = block_out_channels[0] * 4

self.time_proj = Timesteps(block_out_channels[0], True, downscale_freq_shift=0)
timestep_input_dim = block_out_channels[0]

self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)

self.add_time_proj = Timesteps(addition_time_embed_dim, True, downscale_freq_shift=0)
self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)

self.down_blocks = nn.ModuleList([])
self.up_blocks = nn.ModuleList([])

if isinstance(num_attention_heads, int):
    num_attention_heads = (num_attention_heads,) * len(down_block_types)

if isinstance(cross_attention_dim, int):
    cross_attention_dim = (cross_attention_dim,) * len(down_block_types)

if isinstance(layers_per_block, int):
    layers_per_block = [layers_per_block] * len(down_block_types)

if isinstance(transformer_layers_per_block, int):
    transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)

down_block = get_down_block(
    down_block_type,
    num_layers=layers_per_block[i],
    transformer_layers_per_block=transformer_layers_per_block[i],
    in_channels=input_channel,
    out_channels=output_channel,
    temb_channels=blocks_time_embed_dim,
    add_downsample=not is_final_block,
    resnet_eps=1e-5,
    cross_attention_dim=cross_attention_dim[i],
    num_attention_heads=num_attention_heads[i],

up_block = get_up_block(
    up_block_type,
    num_layers=reversed_layers_per_block[i] + 1,
    transformer_layers_per_block=reversed_transformer_layers_per_block[i],
    in_channels=input_channel,

Problem:
cross_attention_dim is validated only when it is a list, not a tuple, even though tuples are accepted and tested. transformer_layers_per_block is expanded when it is an int, but non-int sequence lengths are not validated before indexed access.
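
The root cause is visible in plain Python: a tuple fails the isinstance(..., list) guard, while the suggested not isinstance(..., int) form catches it:

cross_attention_dim = (32,)
print(isinstance(cross_attention_dim, list))      # False: current check skips tuples
print(not isinstance(cross_attention_dim, int))   # True: suggested check covers them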

Impact:
Bad configs fail with IndexError: tuple index out of range during construction instead of a clear config ValueError. Longer tuples can also silently carry unused entries.

Reproduction:

from diffusers import UNetSpatioTemporalConditionModel

try:
    UNetSpatioTemporalConditionModel(
        block_out_channels=(32, 64), layers_per_block=1, sample_size=32, in_channels=8, out_channels=4,
        down_block_types=("CrossAttnDownBlockSpatioTemporal", "DownBlockSpatioTemporal"),
        up_block_types=("UpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal"),
        cross_attention_dim=(32,), num_attention_heads=8,
        projection_class_embeddings_input_dim=96, addition_time_embed_dim=32,
    )
except Exception as e:
    print(type(e).__name__, str(e))
# IndexError tuple index out of range

Relevant precedent:
The same initializer already validates num_attention_heads and layers_per_block sequence lengths before indexing:

if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
    raise ValueError(
        f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
    )

if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
    raise ValueError(
        f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
    )

if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
    raise ValueError(
        f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
    )
Suggested fix:

if not isinstance(cross_attention_dim, int) and len(cross_attention_dim) != len(down_block_types):
    raise ValueError(
        f"Must provide the same number of `cross_attention_dim` as `down_block_types`. "
        f"`cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
    )

if not isinstance(transformer_layers_per_block, int) and len(transformer_layers_per_block) != len(down_block_types):
    raise ValueError(
        f"Must provide the same number of `transformer_layers_per_block` as `down_block_types`. "
        f"`transformer_layers_per_block`: {transformer_layers_per_block}. `down_block_types`: {down_block_types}."
    )
