easyanimate model/pipeline review #13638

@hlky

Description

Commit tested: 0f1abc4ae8b0eb2a3b40e82a310507281144c423

Review performed against the repository review rules.

Duplicate search: checked GitHub Issues/PRs for EasyAnimate, affected class/function names, and the specific failure modes. No exact duplicates found. Related: #12646 reports another crash in the same inpaint repaint branch; #13347 only refactors transformer tests.

Local test note: a project .venv was used. Focused Python reproductions were run; full non-slow pytest collection is blocked in this .venv by ModuleNotFoundError: torch._C._distributed_c10d.

Issue 1: EasyAnimateControlPipeline.__call__ always passes an invalid encode_prompt kwarg

Affected code:

) = self.encode_prompt(
    prompt=prompt,
    device=device,
    dtype=dtype,
    num_images_per_prompt=num_images_per_prompt,
    do_classifier_free_guidance=self.do_classifier_free_guidance,
    negative_prompt=negative_prompt,
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=negative_prompt_embeds,
    prompt_attention_mask=prompt_attention_mask,
    negative_prompt_attention_mask=negative_prompt_attention_mask,
    text_encoder_index=0,
)

Problem:
EasyAnimateControlPipeline.encode_prompt() does not accept text_encoder_index, but __call__ passes text_encoder_index=0. Any control pipeline call reaches this TypeError before denoising.

Impact:
EasyAnimateControlPipeline is effectively unusable through its public __call__.

Reproduction:

from diffusers import EasyAnimateControlPipeline

pipe = object.__new__(EasyAnimateControlPipeline)
EasyAnimateControlPipeline.encode_prompt(pipe, prompt="x", text_encoder_index=0)
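
The failure mode itself is plain Python and reproducible without diffusers; the function below is a hypothetical stand-in for encode_prompt:

```python
# Hypothetical stand-in for encode_prompt: any keyword the signature does
# not declare raises TypeError before the function body ever runs.
def encode_prompt(prompt, device=None, dtype=None):
    return prompt

try:
    encode_prompt(prompt="x", text_encoder_index=0)
except TypeError as exc:
    print(type(exc).__name__)  # prints "TypeError"
```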

Relevant precedent:
Base and inpaint EasyAnimate call encode_prompt without this kwarg:

) = self.encode_prompt(
    prompt=prompt,
    device=device,
    dtype=dtype,
    num_images_per_prompt=num_images_per_prompt,
    do_classifier_free_guidance=self.do_classifier_free_guidance,
    negative_prompt=negative_prompt,
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=negative_prompt_embeds,
    prompt_attention_mask=prompt_attention_mask,
    negative_prompt_attention_mask=negative_prompt_attention_mask,
)

Suggested fix:

# Remove the stale text_encoder_index=0 kwarg from EasyAnimateControlPipeline.__call__
# so the call ends like the base pipeline's:
    prompt_attention_mask=prompt_attention_mask,
    negative_prompt_attention_mask=negative_prompt_attention_mask,
)

Issue 2: EasyAnimateControlPipeline decodes through an undefined method

Affected code:

# Convert to tensor
if not output_type == "latent":
    video = self.decode_latents(latents)
    video = self.video_processor.postprocess_video(video=video, output_type=output_type)

Problem:
The control pipeline calls self.decode_latents(latents), but the class does not define decode_latents.

Impact:
After fixing Issue 1, any control run with output_type != "latent" will fail at decode time.

Reproduction:

from diffusers import EasyAnimateControlPipeline

pipe = object.__new__(EasyAnimateControlPipeline)
print(hasattr(pipe, "decode_latents"))
pipe.decode_latents(None)

Relevant precedent:

if not output_type == "latent":
    latents = 1 / self.vae.config.scaling_factor * latents
    video = self.vae.decode(latents, return_dict=False)[0]
    video = self.video_processor.postprocess_video(video=video, output_type=output_type)

Suggested fix:

latents = 1 / self.vae.config.scaling_factor * latents
video = self.vae.decode(latents, return_dict=False)[0]
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
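
If a method-based fix is preferred instead of inlining, the missing decode_latents can be defined to match the inline precedent. The sketch below uses dummy stand-ins for the VAE, so all names besides the arithmetic are illustrative only:

```python
# Hypothetical stubs mimicking the VAE interface used by the precedent:
# config.scaling_factor rescales latents, decode() returns a 1-tuple.
class DummyVAEConfig:
    scaling_factor = 0.5

class DummyVAE:
    config = DummyVAEConfig()

    def decode(self, latents, return_dict=False):
        return (latents,)  # identity decode for illustration

class PipelineSketch:
    vae = DummyVAE()

    def decode_latents(self, latents):
        # Same two steps as the inline precedent above
        latents = 1 / self.vae.config.scaling_factor * latents
        return self.vae.decode(latents, return_dict=False)[0]

print(PipelineSketch().decode_latents(1.0))  # 2.0
```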

Issue 3: Control helper optional mask/reference paths are broken

Affected code:

def get_video_to_video_latent(input_video, num_frames, sample_size, validation_video_mask=None, ref_image=None):
    if input_video is not None:
        # Convert each frame in the list to tensor
        input_video = [preprocess_image(frame, sample_size=sample_size) for frame in input_video]

        # Stack all frames into a single tensor (F, C, H, W)
        input_video = torch.stack(input_video)[:num_frames]

        # Add batch dimension (B, F, C, H, W)
        input_video = input_video.permute(1, 0, 2, 3).unsqueeze(0)

        if validation_video_mask is not None:
            # Handle mask input
            validation_video_mask = preprocess_image(validation_video_mask, size=sample_size)
            input_video_mask = torch.where(validation_video_mask < 240 / 255.0, 0.0, 255)

            # Adjust mask dimensions to match video
            input_video_mask = input_video_mask.unsqueeze(0).unsqueeze(-1).permute([3, 0, 1, 2]).unsqueeze(0)
            input_video_mask = torch.tile(input_video_mask, [1, 1, input_video.size()[2], 1, 1])
            input_video_mask = input_video_mask.to(input_video.device, input_video.dtype)
        else:
            input_video_mask = torch.zeros_like(input_video[:, :1])
            input_video_mask[:, :, :] = 255
    else:
        input_video, input_video_mask = None, None

    if ref_image is not None:
        # Convert reference image to tensor
        ref_image = preprocess_image(ref_image, size=sample_size)
        ref_image = ref_image.permute(1, 0, 2, 3).unsqueeze(0)  # Add batch dimension (B, C, H, W)

Problem:
get_video_to_video_latent() calls preprocess_image(..., size=sample_size), but the helper parameter is named sample_size. The ref_image branch also builds the wrong rank for the pipeline, which later expects (B, C, F, H, W).

Impact:
Documented/public helper paths for control masks and reference images fail before pipeline execution.

Reproduction:

from PIL import Image
from diffusers.pipelines.easyanimate.pipeline_easyanimate_control import get_video_to_video_latent

frame = Image.new("RGB", (8, 8), "white")
mask = Image.new("RGB", (8, 8), "black")
ref = Image.new("RGB", (8, 8), "blue")

for kwargs in ({"validation_video_mask": mask}, {"ref_image": ref}):
    try:
        get_video_to_video_latent([frame], 1, (8, 8), **kwargs)
    except Exception as e:
        print(type(e).__name__, e)

Relevant precedent:
The main video path already uses sample_size=sample_size at line 123.

Suggested fix:

validation_video_mask = preprocess_image(validation_video_mask, sample_size=sample_size)[:1]
input_video_mask = torch.where(validation_video_mask < 240 / 255.0, 0.0, 255.0)
input_video_mask = input_video_mask.unsqueeze(0).unsqueeze(2).repeat(1, 1, input_video.shape[2], 1, 1)

ref_image = preprocess_image(ref_image, sample_size=sample_size)
ref_image = ref_image.unsqueeze(1).unsqueeze(0)
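
As a shape check, the target (B, C, F, H, W) layout can be reproduced with NumPy. The sketch assumes (hypothetically) that the preprocessed reference is a single (C, H, W) frame; the unsqueeze pattern above corresponds to inserting a frame axis and a batch axis:

```python
import numpy as np

# Assumed single preprocessed frame, (C, H, W)
ref = np.zeros((3, 8, 8))

# Insert frame and batch axes: (C, H, W) -> (1, C, 1, H, W) == (B, C, F, H, W)
ref = ref[np.newaxis, :, np.newaxis, :, :]
assert ref.shape == (1, 3, 1, 8, 8)
```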

Issue 4: Inpaint helper fails for multiple end frames

Affected code:

if validation_image_end is not None:
    if isinstance(validation_image_end, list):
        image_end = [preprocess_image(img, sample_size) for img in validation_image_end]
        end_video = torch.cat(
            [img.unsqueeze(1).unsqueeze(0) for img in image_end],
            dim=2,
        )
        input_video[:, :, -len(end_video) :] = end_video

        input_video_mask[:, :, -len(image_end) :] = 0

Problem:
For validation_image_end as a list, the code slices with len(end_video), which is the batch dimension (1), not the number of end frames.

Impact:
Multi-frame end conditioning raises a shape error and cannot prepare inputs.

Reproduction:

from PIL import Image
from diffusers.pipelines.easyanimate.pipeline_easyanimate_inpaint import get_image_to_video_latent

start = Image.new("RGB", (8, 8), "white")
ends = [Image.new("RGB", (8, 8), "black"), Image.new("RGB", (8, 8), "blue")]
get_image_to_video_latent(start, ends, 4, (8, 8))
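
The root cause is independent of diffusers: len() on a 5-D array returns the leading batch dimension, not the frame count. A NumPy sketch, with shapes assumed to mirror end_video:

```python
import numpy as np

# end_video mirrors (B, C, F, H, W) with 2 end frames
end_video = np.zeros((1, 3, 2, 8, 8))

assert len(end_video) == 1      # leading batch dim, always 1 here
assert end_video.shape[2] == 2  # the number of end frames actually stacked
```

Slicing with -len(end_video) therefore only ever overwrites the last frame, regardless of how many end frames were provided.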

Relevant precedent:
The mask branch already uses len(image_end) on the next line.

Suggested fix:

input_video[:, :, -len(image_end) :] = end_video
input_video_mask[:, :, -len(image_end) :] = 0

Issue 5: Inpaint FlowMatch repaint branch has a malformed torch.tensor call

Affected code:

if num_channels_transformer == num_channels_latents:
    init_latents_proper = image_latents
    init_mask = mask

    if i < len(timesteps) - 1:
        noise_timestep = timesteps[i + 1]
        if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
            init_latents_proper = self.scheduler.scale_noise(
                init_latents_proper, torch.tensor([noise_timestep], noise)
            )

Problem:
torch.tensor([noise_timestep], noise) passes noise as a second positional argument to torch.tensor, which is invalid. It also fails to pass noise to scale_noise.
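
The signature constraint can be mimicked in pure Python: torch.tensor's parameters after data are keyword-only, so a second positional argument fails immediately. The stand-in function below is illustrative only:

```python
# Stand-in mirroring torch.tensor(data, *, dtype=None, device=None, ...):
# everything after data is keyword-only.
def tensor(data, *, dtype=None, device=None):
    return list(data)

try:
    tensor([0.5], [1.0])  # like torch.tensor([noise_timestep], noise)
except TypeError as exc:
    print(type(exc).__name__)  # prints "TypeError"
```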

Impact:
The repaint branch for FlowMatchEulerDiscreteScheduler crashes when num_channels_transformer == num_channels_latents.

Reproduction:

import torch
from diffusers import FlowMatchEulerDiscreteScheduler

scheduler = FlowMatchEulerDiscreteScheduler()
scheduler.set_timesteps(2)

sample = torch.zeros(1, 4, 1, 2, 2)
noise = torch.ones_like(sample)
noise_timestep = scheduler.timesteps[1]

scheduler.scale_noise(sample, torch.tensor([noise_timestep], noise))

Relevant precedent:

if i < len(timesteps) - 1:
    noise_timestep = timesteps[i + 1]
    init_latents_proper = self.scheduler.scale_noise(
        init_latents_proper, torch.tensor([noise_timestep]), noise
    )

Suggested fix:

init_latents_proper = self.scheduler.scale_noise(
    init_latents_proper, torch.tensor([noise_timestep], device=device), noise
)

Issue 6: EasyAnimate attention ignores Diffusers attention backends

Affected code:

class EasyAnimateAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
    used in the EasyAnimateTransformer3DModel model.
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "EasyAnimateAttnProcessor2_0 requires PyTorch 2.0 or above. To use it, please install PyTorch 2.0."
            )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        image_rotary_emb: torch.Tensor | None = None,
    ) -> torch.Tensor:
        if attn.add_q_proj is None and encoder_hidden_states is not None:
            hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

        # 1. QKV projections
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
        key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
        value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)

        # 2. QK normalization
        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # 3. Encoder condition QKV projection and normalization
        if attn.add_q_proj is not None and encoder_hidden_states is not None:
            encoder_query = attn.add_q_proj(encoder_hidden_states)
            encoder_key = attn.add_k_proj(encoder_hidden_states)
            encoder_value = attn.add_v_proj(encoder_hidden_states)

            encoder_query = encoder_query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
            encoder_key = encoder_key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
            encoder_value = encoder_value.unflatten(2, (attn.heads, -1)).transpose(1, 2)

            if attn.norm_added_q is not None:
                encoder_query = attn.norm_added_q(encoder_query)
            if attn.norm_added_k is not None:
                encoder_key = attn.norm_added_k(encoder_key)

            query = torch.cat([encoder_query, query], dim=2)
            key = torch.cat([encoder_key, key], dim=2)
            value = torch.cat([encoder_value, value], dim=2)

        if image_rotary_emb is not None:
            from ..embeddings import apply_rotary_emb

            query[:, :, encoder_hidden_states.shape[1] :] = apply_rotary_emb(
                query[:, :, encoder_hidden_states.shape[1] :], image_rotary_emb
            )
            if not attn.is_cross_attention:
                key[:, :, encoder_hidden_states.shape[1] :] = apply_rotary_emb(
                    key[:, :, encoder_hidden_states.shape[1] :], image_rotary_emb
                )

        # 5. Attention
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
Problem:
EasyAnimateAttnProcessor2_0 calls F.scaled_dot_product_attention directly and does not define _attention_backend / _parallel_config. model.set_attention_backend(...) silently skips the processor.

Impact:
Users cannot select supported attention backends for EasyAnimate, and context-parallel/backend plumbing cannot affect this model.

Reproduction:

from diffusers import EasyAnimateTransformer3DModel

model = EasyAnimateTransformer3DModel(
    num_attention_heads=2,
    attention_head_dim=16,
    in_channels=4,
    out_channels=4,
    time_embed_dim=8,
    text_embed_dim=16,
    num_layers=1,
    mmdit_layers=1,
    patch_size=2,
)

processor = model.transformer_blocks[0].attn1.processor
print(hasattr(processor, "_attention_backend"))
model.set_attention_backend("native")
print(getattr(processor, "_attention_backend", None))
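
Why the setter skips the processor can be sketched in plain Python, assuming (hypothetically) that backend selection only touches processors that expose the class attribute:

```python
# Hypothetical sketch of attribute-gated backend selection: only processors
# that declare _attention_backend can be switched; others are silently skipped,
# as with EasyAnimateAttnProcessor2_0.
class BackendAwareProcessor:
    _attention_backend = None

class LegacyProcessor:
    pass

def set_attention_backend(processor, name):
    if not hasattr(processor, "_attention_backend"):
        return False  # skipped: nowhere to record the chosen backend
    processor._attention_backend = name
    return True

assert set_attention_backend(BackendAwareProcessor(), "native") is True
assert set_attention_backend(LegacyProcessor(), "native") is False
```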

Relevant precedent:

class FluxAttnProcessor:
    _attention_backend = None
    _parallel_config = None

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")

    def __call__(
        self,
        attn: "FluxAttention",
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor = None,
        attention_mask: torch.Tensor | None = None,
        image_rotary_emb: torch.Tensor | None = None,
    ) -> torch.Tensor:
        query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
            attn, hidden_states, encoder_hidden_states
        )

        query = query.unflatten(-1, (attn.heads, -1))
        key = key.unflatten(-1, (attn.heads, -1))
        value = value.unflatten(-1, (attn.heads, -1))

        query = attn.norm_q(query)
        key = attn.norm_k(key)

        if attn.added_kv_proj_dim is not None:
            encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
            encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
            encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))

            encoder_query = attn.norm_added_q(encoder_query)
            encoder_key = attn.norm_added_k(encoder_key)

            query = torch.cat([encoder_query, query], dim=1)
            key = torch.cat([encoder_key, key], dim=1)
            value = torch.cat([encoder_value, value], dim=1)

        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
            key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)

        hidden_states = dispatch_attention_fn(
            query,
            key,
            value,
            attn_mask=attention_mask,
            backend=self._attention_backend,
            parallel_config=self._parallel_config,
        )

Suggested fix:

from ..attention_dispatch import dispatch_attention_fn

class EasyAnimateAttnProcessor2_0:
    _attention_backend = None
    _parallel_config = None

    ...
    hidden_states = dispatch_attention_fn(
        query,
        key,
        value,
        attn_mask=attention_mask,
        dropout_p=0.0,
        is_causal=False,
        backend=self._attention_backend,
        parallel_config=self._parallel_config,
    )

Issue 7: Control and inpaint variants have no fast or slow pipeline coverage

Affected code:

class EasyAnimatePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    pipeline_class = EasyAnimatePipeline


@slow
@require_torch_accelerator
class EasyAnimatePipelineIntegrationTests(unittest.TestCase):
    prompt = "A painting of a squirrel eating a burger."

    def setUp(self):
        super().setUp()
        gc.collect()
        backend_empty_cache(torch_device)

    def tearDown(self):
        super().tearDown()
        gc.collect()
        backend_empty_cache(torch_device)

    def test_EasyAnimate(self):
        generator = torch.Generator("cpu").manual_seed(0)
        pipe = EasyAnimatePipeline.from_pretrained("alibaba-pai/EasyAnimateV5.1-12b-zh", torch_dtype=torch.float16)
        pipe.enable_model_cpu_offload()
        prompt = self.prompt

        videos = pipe(
            prompt=prompt,
            height=480,
            width=720,
            num_frames=5,
            generator=generator,
            num_inference_steps=2,
            output_type="pt",
        ).frames
        video = videos[0]
        expected_video = torch.randn(1, 5, 480, 720, 3).numpy()

        max_diff = numpy_cosine_similarity_distance(video, expected_video)
        assert max_diff < 1e-3, f"Max diff is too high. got {video}"

## EasyAnimatePipeline
[[autodoc]] EasyAnimatePipeline
- all
- __call__
## EasyAnimatePipelineOutput
[[autodoc]] pipelines.easyanimate.pipeline_output.EasyAnimatePipelineOutput

Problem:
The test file only covers EasyAnimatePipeline; there are no fast tests or slow/integration tests for EasyAnimateControlPipeline, EasyAnimateInpaintPipeline, or their public helper functions. Docs also only autodoc the base pipeline.

Impact:
The concrete control/inpaint regressions above are not caught by CI, and public variants are less discoverable.

Reproduction:

from pathlib import Path

tests = Path("tests/pipelines/easyanimate/test_easyanimate.py").read_text()
docs = Path("docs/source/en/api/pipelines/easyanimate.md").read_text()

print("EasyAnimateControlPipeline" in tests)
print("EasyAnimateInpaintPipeline" in tests)
print("[[autodoc]] EasyAnimateControlPipeline" in docs)
print("[[autodoc]] EasyAnimateInpaintPipeline" in docs)

Relevant precedent:
Existing base pipeline fast and slow tests are in the same file and can be extended with tiny control/inpaint fixtures.

Suggested fix:
Add focused fast tests for control and inpaint using tiny components, including output_type="pt" and "latent" paths plus helper utility tests. Add slow tests for official control and inpaint checkpoints, or explicitly mark/document why they cannot be run. Update the EasyAnimate docs to autodoc EasyAnimateControlPipeline and EasyAnimateInpaintPipeline.
