
feat: Add Motif-Video model and pipelines #13551

Open
waitingcheung wants to merge 20 commits into huggingface:main from waitingcheung:feat/motif-video

Conversation

waitingcheung commented Apr 23, 2026

What does this PR do?

This PR adds support for Motif-Video, a text-to-video (T2V) and image-to-video (I2V) diffusion model from Motif Technologies. The implementation includes the transformer architecture, both pipeline variants, guider configurations, and comprehensive documentation.

Changes

New Files

  • Model: src/diffusers/models/transformers/transformer_motif_video.py - MotifVideoTransformer3DModel
  • Pipelines:
    • src/diffusers/pipelines/motif_video/pipeline_motif_video.py - Text-to-Video
    • src/diffusers/pipelines/motif_video/pipeline_motif_video_image2video.py - Image-to-Video
  • Output: src/diffusers/pipelines/motif_video/pipeline_output.py
  • Tests:
    • tests/pipelines/motif_video/test_motif_video.py
    • tests/pipelines/motif_video/test_motif_video_image2video.py
  • Documentation:
    • docs/source/en/api/models/motif_video_transformer_3d.md
    • docs/source/en/api/pipelines/motif_video.md

Key Features

  • Architecture: DiT-based transformer with T5Gemma2Encoder for text encoding
  • Flow Match: Uses FlowMatchEulerDiscreteScheduler
  • Guidance: Supports ClassifierFreeGuidance, SkipLayerGuidance, and AdaptiveProjectedGuidance
  • Video Processing: Wan-style VAE for video encoding/decoding
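
For quick reference, a minimal text-to-video usage sketch. The checkpoint repo id below is a placeholder, and the frames output attribute is assumed to follow the usual diffusers video-pipeline convention:

import torch
from diffusers import MotifVideoPipeline
from diffusers.utils import export_to_video

# Placeholder repo id; substitute the actual released checkpoint
pipe = MotifVideoPipeline.from_pretrained(
    "Motif-Technologies/Motif-Video", torch_dtype=torch.bfloat16
).to("cuda")

# Pipeline defaults: 736x1280, 121 frames at 25 fps
video = pipe(
    prompt="A red panda climbing a snowy pine tree",
    height=736,
    width=1280,
    num_frames=121,
).frames[0]
export_to_video(video, "motif_video.mp4", fps=25)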

Version Requirements

  • transformers>=5.1.0 - Required for T5Gemma2Encoder (critical bug fix in PR #43633)
  • The pipeline includes a version check that raises a clear error with upgrade instructions if the transformers version is too old


Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

github-actions Bot added the documentation, models, tests, utils, pipelines, guiders, and size/L (PR with diff > 200 LOC) labels Apr 23, 2026
waitingcheung changed the title from "Add Motif Video model and pipelines" to "Add Motif-Video model and pipelines" Apr 23, 2026
waitingcheung (Author) commented Apr 23, 2026

@yiyixuxu @asomoza @sayakpaul

Quick ping for visibility. This PR adds Motif-Video (T2V/I2V + new transformer and pipelines).

Would appreciate your feedback, especially on dependency/version constraints:

  • transformers>=5.1.0 for T5Gemma2Encoder (currently enforced via an assertion with an upgrade message)
  • compel requiring transformers<5, which may conflict with diffusers usage

This is currently blocking some diffusers-side integration, so your input would help.

A working branch for this integration is available here.

waitingcheung marked this pull request as ready for review April 23, 2026 06:07
waitingcheung changed the title from "Add Motif-Video model and pipelines" to "feat: Add Motif-Video model and pipelines" Apr 23, 2026
…dance support

Add complete Motif Video implementation to diffusers:

New Models:
- Add MotifVideoTransformer3DModel with T5Gemma2Encoder for multimodal conditioning
- Supports text-to-video and image-to-video generation with vision tower integration

New Pipelines:
- Add MotifVideoPipeline for text-to-video generation
  - Default resolution: 736x1280, 121 frames, 25 fps
  - Supports classifier-free guidance and AdaptiveProjectedGuidance
- Add MotifVideoImage2VideoPipeline for image-to-video generation
  - First frame conditioning with vision encoder
  - Same defaults as T2V pipeline

Enhanced Guidance:
- Update AdaptiveProjectedGuidance with normalization_dims parameter
  - Support "spatial" normalization for 5D tensors (per-frame spatial normalization)
  - Support custom dimension lists for flexible normalization
  - Update AdaptiveProjectedMixGuidance with same parameter

Documentation & Tests:
- Add comprehensive API documentation for transformer and pipelines
- Add test suites for both T2V and I2V pipelines
- Register all new components in __init__ files
- Add dummy objects for torch and transformers backends

Total: 18 files changed, 3416 insertions(+), 2 deletions(-)
github-actions Bot added the single-file label Apr 23, 2026
sayakpaul (Member) commented:

> transformers>=5.1.0 for T5Gemma2Encoder (currently enforced via an assertion with an upgrade message)

I think we can guard the transformers import in the pipeline with something like is_transformers_version("<", "5.1.0")?

compel conflict is fine IMO.

sayakpaul requested review from dg845 and yiyixuxu April 23, 2026 10:25
waitingcheung (Author) commented:

> > transformers>=5.1.0 for T5Gemma2Encoder (currently enforced via an assertion with an upgrade message)
>
> I think we can guard the transformers import in the pipeline with something like is_transformers_version("<", "5.1.0")?
>
> compel conflict is fine IMO.

We have something like this at the top of the pipeline code to guide users to upgrade the transformers package before importing T5Gemma2Encoder:

# Check transformers version before importing T5Gemma2Encoder
if not is_transformers_version(">=", "5.1.0"):
    import transformers

    raise ImportError(
        f"MotifVideoPipeline requires transformers>=5.1.0. "
        f"Found: {transformers.__version__}. "
        "Please upgrade transformers: pip install transformers --upgrade"
    )

sayakpaul (Member) commented:

Then that should cut it.

waitingcheung (Author) commented:

@dg845, @yiyixuxu
I would appreciate your feedback on this PR when you have a moment.

Ken Cheung and others added 4 commits April 29, 2026 05:33
- Use default Attention class with custom MotifVideoAttnProcessor2_0
- Inline cross-attention in transformer blocks
- Use dispatch_attention_fn for backend support
- Inherit AttentionMixin for attn_processors/set_attn_processor
- Move TransformerBlockRegistry to _helpers.py
- Add _repeated_blocks for regional compilation
github-actions Bot added the hooks label Apr 29, 2026
waitingcheung (Author) commented:

@dg845 Thank you again for the initial design review. I have addressed all your comments above.

To hold all of the attention parameters in an Attention-style class as suggested, I have renamed the cross_attn weight keys to use an attn prefix. You may need to download the transformer weights again to test the latest changes.
#13551 (comment)

I would be glad to address any further comments you have. Thank you.

yiyixuxu (Collaborator) commented:

@claude can you do a review here?

github-actions Bot (Contributor) commented Apr 29, 2026

Claude finished @yiyixuxu's task in 3m 58s


Review of PR #13551: Motif-Video model and pipelines

  • Read repository guidelines (AGENTS.md, models.md, pipelines.md, modular.md)
  • Review transformer model (transformer_motif_video.py)
  • Review T2V pipeline (pipeline_motif_video.py)
  • Review I2V pipeline (pipeline_motif_video_image2video.py)
  • Review pipeline output and __init__ files
  • Review guider changes
  • Review hooks helper changes
  • Review single file model loader changes
  • Post review findings

Thanks for the contribution! The overall structure follows diffusers conventions well — the AttentionMixin/AttentionModuleMixin refactor, dispatch_attention_fn usage, hooks registration in _helpers.py, and the guider abstraction all look good. I have feedback organized by severity below.


Correctness / Bugs

1. torch.float64 in RoPE without device guard (transformer_motif_video.py:377)

MotifVideoRotaryPosEmbed.forward() uses torch.float64 unconditionally:

freqs_dtype=torch.float64,

Per the model conventions (models.md gotcha #6), MPS and NPU backends don't support float64. Every other transformer in the repo uses the device-gated pattern:

is_mps = hidden_states.device.type == "mps"
is_npu = hidden_states.device.type == "npu"
freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64

See transformer_flux.py:508, transformer_wan.py:375, etc. This will cause runtime errors on MPS devices.
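
A sketch of how the guard could slot into the RoPE forward; the surrounding method body is assumed from the PR:

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
    # MPS and NPU lack float64 kernels; fall back to float32 there
    is_mps = hidden_states.device.type == "mps"
    is_npu = hidden_states.device.type == "npu"
    freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
    # ... build the rotary frequencies with freqs_dtype instead of the
    # hard-coded torch.float64 ...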

2. Docstring defaults don't match signature defaults (pipeline_motif_video.py:504-508, pipeline_motif_video_image2video.py:618-621)

The docstrings say height defaults to 352, width defaults to 640, num_frames defaults to 65, but the actual function signatures use height=736, width=1280, num_frames=121. These should be consistent.

3. MotifVideoConditionEmbedding.forward reads weight dtype (transformer_motif_video.py:322-326)

compute_dtype = next(
    (p.dtype for p in self.timestep_embedder.parameters() if p.is_floating_point()),
    torch.float32,
)
conditioning = self.timestep_embedder(timesteps_proj.to(compute_dtype))

Per models.md gotcha #5: "don't cast activations by reading a weight's dtype — the stored weight dtype isn't the compute dtype under gguf / quantized loading." Derive the cast target from the input tensor's dtype instead.
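
A minimal sketch of the suggested direction, assuming the forward receives (or can receive) an activation tensor whose dtype reflects the actual compute dtype:

# Cast based on the incoming activation, not the stored weights, so
# quantized / gguf-loaded checkpoints still see the right compute dtype
conditioning = self.timestep_embedder(timesteps_proj.to(hidden_states.dtype))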

4. _current_sigma injection has no counterpart in guiders (pipeline_motif_video.py:703-704)

if hasattr(self.guider, "_current_sigma") and hasattr(self.scheduler, "sigmas"):
    self.guider._current_sigma = float(self.scheduler.sigmas[i])

No guider in the repo has a _current_sigma attribute, so this code is dead. If this is meant for a future APG feature, it shouldn't be here — remove it to avoid confusion. If it's needed for the normalization_dims changes in APG, that relationship should be established properly (and the APG changes should have tests).


Architecture / Pattern Issues

5. Inline cross-attention bypasses the attention processor system (transformer_motif_video.py:471-503 and 585-616)

Both MotifVideoSingleTransformerBlock.forward() and MotifVideoTransformerBlock.forward() call dispatch_attention_fn directly in the block's forward method for text cross-attention, rather than routing through an attention processor. This is inconsistent with every other model in the repo (Flux2, Cosmos, etc. all route through processors). This means:

  • Custom attention backends won't apply to the cross-attention path
  • The set_attn_processor API won't affect cross-attention
  • It breaks composability with attention hooks

Consider either: (a) incorporating the cross-attention into the processor, or (b) creating a separate MotifVideoAttention instance for cross-attention in each block and routing through the processor as usual.
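
A rough sketch of option (b); the constructor arguments are illustrative, not the PR's actual signature:

# In the block's __init__: a second attention module dedicated to text
# cross-attention, with its own processor
self.attn2 = MotifVideoAttention(
    query_dim=dim,
    cross_attention_dim=text_dim,
    processor=MotifVideoAttnProcessor2_0(),
)

# In forward(): route through the processor instead of calling
# dispatch_attention_fn inline
attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
hidden_states = hidden_states + attn_output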

6. _available_processors not set on MotifVideoAttention (transformer_motif_video.py:162)

Per the attention pattern in models.md, the attention class should declare _available_processors. Compare with Flux2Attention at transformer_flux2.py:495:

_available_processors = [Flux2AttnProcessor, Flux2KVAttnProcessor]

Add:

_available_processors = [MotifVideoAttnProcessor2_0]

7. Missing # Copied from annotations on shared methods in I2V pipeline (pipeline_motif_video_image2video.py)

Several methods are identical between T2V and I2V but lack # Copied from annotations, so they'll silently drift:

  • check_inputs (I2V version has slight differences — if intentional, that's fine)
  • _prepare_negative_prompt (~line 501)
  • _normalize_latents / _denormalize_latents (~lines 514-529)
  • prepare_latents (~line 531)

For the ones that are truly identical, add # Copied from diffusers.pipelines.motif_video.pipeline_motif_video.MotifVideoPipeline.<method> so make fix-copies keeps them in sync.
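
For example, above the I2V copy of one of the shared helpers (illustrative placement):

# Copied from diffusers.pipelines.motif_video.pipeline_motif_video.MotifVideoPipeline._normalize_latents
def _normalize_latents(self, latents):
    ...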

8. Redundant torch.no_grad() inside _prepare_first_frame_conditioning (pipeline_motif_video_image2video.py:408, 427)

The method is only called from __call__, which is decorated with @torch.no_grad(). Per pipelines.md gotcha #2, inner with torch.no_grad(): is noise and forecloses callers who want grads enabled (e.g., training, embedding optimization).


Minor / Nits

9. frame_rate parameter is accepted but unused (pipeline_motif_video.py:474, pipeline_motif_video_image2video.py:589)

Both pipelines accept frame_rate as a __call__ parameter but never use it. Either remove it or document why it exists (e.g., for metadata purposes). Unused parameters confuse users.

10. use_attention_mask parameter (pipeline_motif_video.py:492)

The T2V pipeline has a use_attention_mask parameter that conditionally omits the attention mask from guider inputs. The I2V pipeline doesn't have this parameter and always passes the mask. This inconsistency seems unintentional — the I2V pipeline should probably have the same parameter.

11. Single-file model loader mapping is a no-op lambda (single_file_model.py:201)

"MotifVideoTransformer3DModel": {
    "checkpoint_mapping_fn": lambda checkpoint, **kwargs: checkpoint,
    "default_subfolder": "transformer",
},

A no-op mapping function means the checkpoint is expected to already be in diffusers format. If this is intentional (i.e., no original format exists to convert from), it's fine but somewhat unusual. If there's an original checkpoint format, a proper conversion function should be added.

12. Guider normalization_dims changes (adaptive_projected_guidance.py, adaptive_projected_guidance_mix.py)

The new normalization_dims parameter in both APG guiders is a reasonable addition, but it's completely unrelated to the Motif-Video model/pipeline integration. It should be split into a separate PR so it can be reviewed and tested independently. The Motif-Video pipeline doesn't appear to use this parameter directly.
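
For context, a standalone sketch of what "spatial" normalization over 5D video latents appears to mean here; this is not the guider's actual code:

import torch

def spatial_normalize(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # x: (batch, channels, frames, height, width); normalize each frame's
    # spatial extent independently, leaving batch/channel/frame untouched
    norm = torch.linalg.vector_norm(x, ord=2, dim=(-2, -1), keepdim=True)
    return x / (norm + eps)

x = torch.randn(1, 16, 21, 92, 160)
print(spatial_normalize(x).shape)  # torch.Size([1, 16, 21, 92, 160])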

13. Formatting-only changes in single_file_model.py

Several hunks in this file are purely reformatting (line wrapping import statements, function arguments). These add noise to the diff and should ideally be separate or omitted.


Summary

The core model architecture and pipeline structure are solid and follow diffusers conventions. The main areas that need attention are:

  1. Must fix: torch.float64 device guard (will crash on MPS)
  2. Must fix: Docstring defaults vs actual defaults mismatch
  3. Should fix: Inline cross-attention should go through the processor system
  4. Should fix: _available_processors declaration missing
  5. Should fix: Missing # Copied from annotations on shared methods
  6. Cleanup: Remove dead _current_sigma code, redundant torch.no_grad(), unused frame_rate parameter
  7. Scope: Consider splitting the APG normalization_dims changes into a separate PR
