Skip to content

Commit d96121b

Browse files
committed
[MAX] Add Wan T2V diffusion pipeline with MoE support
## Summary

Add the Wan text-to-video (T2V) diffusion pipeline with MoE (Mixture of Experts) dual-transformer support.

## Description

- Implements the full Wan T2V pipeline: text encoding → latent preparation → denoising loop → VAE decode
- Supports **Wan 2.2 MoE models** (A14B) with dual transformers: high-noise expert for early steps, low-noise expert for later steps, with configurable boundary timestep
- Supports **Wan 2.1 single-transformer models** (14B) with the same code path
- LoRA support with automatic download from HuggingFace (e.g. Lightning turbo LoRAs for 4-step generation)
- Classifier-free guidance with batched forward pass (positive + negative in one call)
- On-device UniPC scheduler steps via compiled graphs — no Python-side numpy during denoising
- Architecture registration for `Wan-AI/Wan2.2-T2V-A14B-Diffusers`, `Wan-AI/Wan2.1-T2V-14B-Diffusers`, etc.
- Adds `guidance_scale_2` field to `VideoProviderOptions` for MoE boundary guidance control
- Minimal upstream changes: only `_weight_paths` storage in `DiffusionPipeline.__init__` and Wan registration in `pixel_tokenizer.py` / `registry.py`

## Dependencies

Depends on all previous PRs: modular#6298 (scheduler), modular#6299 (UMT5), modular#6300 (VAE), modular#6301 (transformer).

## Checklist

- [x] PR is small and focused
- [x] I ran `./bazelw run format` to format my changes

Assisted-by: Claude Code
1 parent 15eb852 commit d96121b

11 files changed

Lines changed: 1869 additions & 26 deletions

File tree

max/kernels/src/nn/conv/conv.mojo

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5280,12 +5280,13 @@ def _conv3d_cudnn[
52805280
var algo: cudnnConvolutionFwdAlgo_t
52815281
var workspace_size_var: Int
52825282

5283-
if ptr_cached := _get_global_or_null(cache_key).bitcast[
5284-
_Conv3dAlgoCacheEntry
5285-
]():
5283+
if ptr_cached := _get_global_or_null(cache_key):
5284+
var cached = ptr_cached.unsafe_value().bitcast[
5285+
_Conv3dAlgoCacheEntry
5286+
]()
52865287
# Cache hit — reuse previously selected algorithm.
5287-
algo = ptr_cached[].algo()
5288-
workspace_size_var = ptr_cached[].workspace_size
5288+
algo = cached[].algo()
5289+
workspace_size_var = cached[].workspace_size
52895290
else:
52905291
# Cache miss — run FindEx to find the fastest algorithm.
52915292
var find_ws = ctx.enqueue_create_buffer[DType.uint8](FIND_WS_CAP)

max/python/max/interfaces/provider_options/modality/video.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,3 +66,9 @@ class VideoProviderOptions(BaseModel):
6666
),
6767
gt=0,
6868
)
69+
70+
guidance_scale_2: float | None = Field(
71+
None,
72+
description="Secondary guidance scale for boundary timestep switching.",
73+
gt=0.0,
74+
)

max/python/max/pipelines/architectures/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ def register_all_models() -> None:
8282
from .qwen3vl_moe import qwen3vl_arch, qwen3vl_moe_arch
8383
from .unified_eagle_llama3 import unified_eagle_llama3_arch
8484
from .unified_mtp_deepseekV3 import unified_mtp_deepseekV3_arch
85+
from .wan import wan_arch, wan_i2v_arch
8586
from .z_image_modulev3 import z_image_arch
8687

8788
architectures = [
@@ -137,6 +138,8 @@ def register_all_models() -> None:
137138
qwen3vl_moe_arch,
138139
unified_eagle_llama3_arch,
139140
unified_mtp_deepseekV3_arch,
141+
wan_arch,
142+
wan_i2v_arch,
140143
z_image_arch,
141144
]
142145

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# ===----------------------------------------------------------------------=== #
2+
# Copyright (c) 2026, Modular Inc. All rights reserved.
3+
#
4+
# Licensed under the Apache License v2.0 with LLVM Exceptions:
5+
# https://llvm.org/LICENSE.txt
6+
#
7+
# Unless required by applicable law or agreed to in writing, software
8+
# distributed under the License is distributed on an "AS IS" BASIS,
9+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10+
# See the License for the specific language governing permissions and
11+
# limitations under the License.
12+
# ===----------------------------------------------------------------------=== #
13+
14+
from .arch import wan_arch, wan_i2v_arch
15+
from .model import WanTransformerModel
16+
from .pipeline_wan import WanPipeline
17+
from .pipeline_wan_i2v import WanI2VPipeline
18+
19+
__all__ = [
20+
"WanI2VPipeline",
21+
"WanPipeline",
22+
"WanTransformerModel",
23+
"wan_arch",
24+
"wan_i2v_arch",
25+
]
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
# ===----------------------------------------------------------------------=== #
2+
# Copyright (c) 2026, Modular Inc. All rights reserved.
3+
#
4+
# Licensed under the Apache License v2.0 with LLVM Exceptions:
5+
# https://llvm.org/LICENSE.txt
6+
#
7+
# Unless required by applicable law or agreed to in writing, software
8+
# distributed under the License is distributed on an "AS IS" BASIS,
9+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10+
# See the License for the specific language governing permissions and
11+
# limitations under the License.
12+
# ===----------------------------------------------------------------------=== #
13+
14+
from __future__ import annotations
15+
16+
from dataclasses import dataclass
17+
18+
from max.graph.weights import WeightsFormat
19+
from max.interfaces import PipelineTask
20+
from max.pipelines.core import PixelContext
21+
from max.pipelines.lib import (
22+
PixelGenerationTokenizer,
23+
SupportedArchitecture,
24+
)
25+
from max.pipelines.lib.config import MAXModelConfig, PipelineConfig
26+
from max.pipelines.lib.interfaces import ArchConfig
27+
from typing_extensions import Self
28+
29+
from .pipeline_wan import WanPipeline
30+
from .pipeline_wan_i2v import WanI2VPipeline
31+
32+
33+
@dataclass(kw_only=True)
class WanArchConfig(ArchConfig):
    """Architecture config for the Wan pipelines.

    Implements the ``ArchConfig`` interface at pipeline level; Wan keeps no
    KV cache.
    """

    # Pipeline configuration handed to ``initialize`` and stored as-is.
    pipeline_config: PipelineConfig

    def get_max_seq_len(self) -> int:
        """Return the fixed maximum sequence length.

        The value is the tokenizer padding length and matches the default
        used by the diffusers ``__call__``.
        """
        return 512

    @classmethod
    def initialize(
        cls,
        pipeline_config: PipelineConfig,
        model_config: MAXModelConfig | None = None,
    ) -> Self:
        """Build a config after checking the single-device requirement.

        Args:
            pipeline_config: Pipeline configuration stored on the result.
            model_config: Model configuration whose ``device_specs`` are
                checked; when it is falsy, ``pipeline_config.model`` is
                used instead.

        Returns:
            A new instance wrapping ``pipeline_config``.

        Raises:
            ValueError: If the selected model config does not list exactly
                one device spec.
        """
        effective_model = (
            model_config if model_config else pipeline_config.model
        )
        if len(effective_model.device_specs) == 1:
            return cls(pipeline_config=pipeline_config)
        raise ValueError("Wan is only supported on a single device")
53+
54+
55+
# Registration entry for the architecture named "WanPipeline"; the example
# repos below are the T2V / TI2V Diffusers checkpoints.
wan_arch = SupportedArchitecture(
    # Identity and task.
    name="WanPipeline",
    task=PipelineTask.PIXEL_GENERATION,
    example_repo_ids=[
        "Wan-AI/Wan2.2-T2V-A14B-Diffusers",
        "Wan-AI/Wan2.1-T2V-14B-Diffusers",
        "Wan-AI/Wan2.2-TI2V-5B-Diffusers",
        "yetter-ai/Wan2.2-TI2V-5B-Turbo-Diffusers",
    ],
    # Implementation classes wired into the pipeline.
    pipeline_model=WanPipeline,  # type: ignore[arg-type]
    config=WanArchConfig,
    tokenizer=PixelGenerationTokenizer,
    context_type=PixelContext,
    # Weight format and encodings.
    default_weights_format=WeightsFormat.safetensors,
    default_encoding="bfloat16",
    supported_encodings={"bfloat16", "float32"},
)
72+
73+
# Registration entry for the architecture named "WanImageToVideoPipeline";
# the example repos below are the I2V Diffusers checkpoints.
wan_i2v_arch = SupportedArchitecture(
    # Identity and task.
    name="WanImageToVideoPipeline",
    task=PipelineTask.PIXEL_GENERATION,
    example_repo_ids=[
        "Wan-AI/Wan2.2-I2V-A14B-Diffusers",
        "Wan-AI/Wan2.1-I2V-14B-720P-Diffusers",
    ],
    # Implementation classes wired into the pipeline.
    pipeline_model=WanI2VPipeline,
    config=WanArchConfig,
    tokenizer=PixelGenerationTokenizer,
    context_type=PixelContext,
    # Weight format and encodings.
    default_weights_format=WeightsFormat.safetensors,
    default_encoding="bfloat16",
    supported_encodings={"bfloat16", "float32"},
)

0 commit comments

Comments
 (0)