Commit 796f5b5

Add wan 2.1 model (#1409)

1 parent e52c128

24 files changed: 3411 additions & 0 deletions

README.md

Lines changed: 4 additions & 0 deletions

```diff
@@ -26,6 +26,10 @@ package for LLMs with MLX.
 - Image classification using [ResNets on CIFAR-10](cifar).
 - Convolutional variational autoencoder [(CVAE) on MNIST](cvae).
 
+### Video Models
+
+- Text-to-video and image-to-video generation with [Wan2.1](video/wan2.1).
+
 ### Audio Models
 
 - Speech recognition with [OpenAI's Whisper](whisper).
```

video/wan2.1/README.md

Lines changed: 154 additions & 0 deletions

Wan2.1
======

Wan2.1 text-to-video and image-to-video implementation in MLX. The model
weights are downloaded directly from the [Hugging Face
Hub](https://huggingface.co/Wan-AI).

| Model | Task | HF Repo | RAM (unquantized, 81 frames) | Single DiT step on an M4 Max (81 frames) |
|-------|------|---------|------------------------------|------------------------------------------|
| 1.3B | T2V | [Wan-AI/Wan2.1-T2V-1.3B](https://huggingface.co/Wan-AI/Wan2.1-T2V-1.3B) | ~10GB | ~90 s/it |
| 14B | T2V | [Wan-AI/Wan2.1-T2V-14B](https://huggingface.co/Wan-AI/Wan2.1-T2V-14B) | ~36GB | ~230 s/it |
| 14B | I2V | [Wan-AI/Wan2.1-I2V-14B-480P](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-480P) | ~39GB | ~250 s/it |

| T2V 1.3B | T2V 14B | I2V 14B |
|---|---|---|
| ![Wan T2V 1.3B](static/out_t2v_1_3b.gif) | ![Wan T2V 14B distilled](static/out_t2v_cats.gif) | ![Wan I2V 14B distilled](static/out_i2v_astronaut.gif) |
| Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage. | Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage. | An astronaut riding a horse |
Installation
------------

Install the dependencies:

```shell
pip install -r requirements.txt
```

Saving videos requires [ffmpeg](https://ffmpeg.org/) on your PATH.
Usage
-----

### Text-to-Video

Generate a video with the default 1.3B model:

```shell
python txt2video.py 'A cat playing piano' --output out.mp4
```

Use the 14B model with quantization:

```shell
python txt2video.py 'A cat playing piano' \
    --model t2v-14B --quantize --output out_14B.mp4
```

Adjust resolution, frame count, and sampling parameters:

```shell
python txt2video.py 'Ocean waves crashing on a rocky shore at sunset' \
    --size 832x480 --frames 81 --steps 50 --guidance 5.0 --seed 42 \
    --output waves.mp4
```

For more parameters, use `python txt2video.py --help`.
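The `--size` flag expects a `WxH` string. In this commit's CLI scripts it is parsed with a small lambda; a standalone equivalent, for reference:

```python
def parse_size(value: str) -> tuple[int, int]:
    # "832x480" -> (832, 480); same parsing the CLI applies to --size
    width, height = map(int, value.split("x"))
    return (width, height)

print(parse_size("832x480"))  # (832, 480)
```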
### Image-to-Video

Generate a video from an input image:

```shell
python img2video.py 'Astronaut riding a horse' \
    --image ./inputs/astronaut-on-a-horse.png --quantize --output out_i2v.mp4
```

Adjust resolution and sampling parameters:

```shell
python img2video.py 'Astronaut riding a horse' \
    --image ./inputs/astronaut-on-a-horse.png --size 832x480 --frames 81 --steps 40 \
    --guidance 5.0 --shift 3.0 --seed 42 --output out_i2v.mp4
```

For more parameters, use `python img2video.py --help`.
### Quantization

Pass `--quantize` (or `-q`) to quantize the DiT weights to 8 bits, or give an
explicit value (`--quantize 4`) for 4-bit quantization:

```shell
python txt2video.py 'A cat playing piano' --quantize --output out_quantized.mp4
```
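As a rough illustration of why quantization shrinks the RAM figures in the table above (these are not measured numbers from this repo), weight memory scales with bits per parameter. The sketch below assumes group-wise quantization with a group size of 64 and two fp16 scale/bias values per group:

```python
def approx_weight_gib(n_params: float, bits: int, group_size: int = 64) -> float:
    # Quantized weight bits plus per-group fp16 scale and bias overhead.
    bits_per_weight = bits + 2 * 16 / group_size
    return n_params * bits_per_weight / 8 / 1024**3

# Back-of-envelope for a 14B-parameter DiT at different bit widths.
for bits in (16, 8, 4):
    print(f"{bits}-bit: ~{approx_weight_gib(14e9, bits):.1f} GiB")
```

Activations, the VAE, and the text encoder add on top of this, which is why the end-to-end figures are higher than the weight estimate alone.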
### Disabling the cache

For additional memory savings at the expense of some speed, pass `--no-cache`.
It prevents MLX from using its buffer cache (it calls `mx.set_cache_limit(0)`
under the hood). See the
[documentation](https://ml-explore.github.io/mlx/build/html/python/_autosummary/mlx.core.set_cache_limit.html)
for more info.

```shell
python txt2video.py 'A cat playing piano' --output out.mp4 --no-cache
```

For the 1.3B model at 480p and 81 frames, a `--no-cache` run uses ~10GB of RAM,
versus ~14GB otherwise.
### Custom DiT Weights

Use `--checkpoint` to load custom DiT weights (e.g.
[step-distilled models](https://huggingface.co/lightx2v/Wan2.1-Distill-Models)),
and pass `--sampler euler` to use Euler sampling with step-distilled models.

For the text-to-video pipeline, you can try
[this 4-step distilled model](https://huggingface.co/lightx2v/Wan2.1-Distill-Models/blob/main/wan2.1_t2v_14b_lightx2v_4step.safetensors):

```shell
wget https://huggingface.co/lightx2v/Wan2.1-Distill-Models/resolve/main/wan2.1_t2v_14b_lightx2v_4step.safetensors
```

```shell
python txt2video.py 'A cat playing piano' \
    --model t2v-14B --checkpoint ./wan2.1_t2v_14b_lightx2v_4step.safetensors \
    --sampler euler --steps 4 --guidance 1.0 \
    --quantize --output out_t2v_distilled.mp4
```

For the image-to-video pipeline, use the
[4-step distilled I2V model](https://huggingface.co/lightx2v/Wan2.1-Distill-Models/resolve/main/wan2.1_i2v_480p_lightx2v_4step.safetensors):

```shell
wget https://huggingface.co/lightx2v/Wan2.1-Distill-Models/resolve/main/wan2.1_i2v_480p_lightx2v_4step.safetensors
```

```shell
python img2video.py 'Astronaut riding a horse' \
    --image ./inputs/astronaut-on-a-horse.png --checkpoint ./wan2.1_i2v_480p_lightx2v_4step.safetensors \
    --sampler euler --steps 4 --guidance 1.0 --shift 5.0 \
    --quantize --output out_i2v_distilled.mp4
```
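Under `--sampler euler`, the CLI scripts in this commit derive an evenly spaced list of denoising timesteps from `--steps`. The same computation, standalone:

```python
def euler_step_list(num_steps: int, num_train_timesteps: int = 1000) -> list[int]:
    # Evenly spaced timesteps from high noise to low,
    # e.g. 4 steps -> [1000, 750, 500, 250]
    return [num_train_timesteps * i // num_steps for i in range(num_steps, 0, -1)]

print(euler_step_list(4))  # [1000, 750, 500, 250]
```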
### Options

- **Negative prompts**: `--n-prompt 'blurry, low quality, distorted'`
- **Disable CFG**: `--guidance 1.0` skips the unconditional pass, roughly
  halving compute per step.
### TeaCache

[TeaCache](https://arxiv.org/abs/2411.19108) skips redundant transformer
computations when consecutive steps produce similar embeddings, eliminating
20-60% of forward passes. Note that the TeaCache parameters are calibrated per
resolution; consult the
[LightX2V](https://github.com/ModelTC/LightX2V/tree/main/configs/caching)
configs for advanced tweaking. Our defaults are located in
[pipeline.py](./wan/pipeline.py#L20).

```shell
python txt2video.py 'A cat playing piano' --teacache 0.05 --output out.mp4 --verbose
```
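The skipping logic can be pictured with a toy sketch (hypothetical names, not this repo's implementation): accumulate an estimate of how much the transformer input changed since the last full forward pass, and reuse the cached residual while the running total stays below the threshold:

```python
def teacache_plan(rel_changes: list[float], threshold: float) -> list[int]:
    # rel_changes[i]: relative change of the modulated input at step i.
    skipped, accumulated = [], 0.0
    for step, change in enumerate(rel_changes):
        accumulated += change
        if accumulated < threshold:
            skipped.append(step)  # reuse cached residual, skip forward pass
        else:
            accumulated = 0.0     # run the transformer, reset accumulator
    return skipped

# Small steady changes accumulate until a real forward pass is forced.
print(teacache_plan([0.02, 0.02, 0.02, 0.02, 0.02], 0.05))  # [0, 1, 3, 4]
```

A higher threshold lets more error accumulate before a full pass is run, which is why the skip rate and quality loss in the table below both grow with it.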
Recommended thresholds (1.3B):

| Threshold | Skip Rate | Quality |
|-----------|-----------|---------|
| `0.05` | ~34% | Almost lossless |
| `0.1` | ~58% | Slightly corrupted |
| `0.25` | ~76% | Visible quality loss |

#### Results with `--teacache` for the 1.3B model

`Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.`

| `--teacache 0.05`, 34% steps skipped (17/50) | `--teacache 0.1`, 58% steps skipped (29/50) | `--teacache 0.25`, 76% steps skipped (38/50) |
|---|---|---|
| ![Wan T2V 1.3B teacache=0.05](static/out_t2v_1_3b_teacache_005.gif) | ![Wan T2V 1.3B teacache=0.1](static/out_t2v_1_3b_teacache_01.gif) | ![Wan T2V 1.3B teacache=0.25](static/out_t2v_1_3b_teacache_025.gif) |
References
----------

1. [Original WAN 2.1 implementation](https://github.com/Wan-Video/Wan2.1)
2. [LightX2V](https://github.com/ModelTC/LightX2V)

video/wan2.1/img2video.py

Lines changed: 158 additions & 0 deletions

```python
# Copyright © 2026 Apple Inc.

"""Generate videos from an image and text prompt using Wan2.1 I2V."""

import argparse
import logging

import mlx.core as mx
import mlx.nn as nn
from tqdm import tqdm

from wan import WanPipeline
from wan.utils import save_video

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Generate videos from an image and text prompt using Wan2.1 I2V"
    )
    parser.add_argument("prompt")
    parser.add_argument("--image", required=True, help="Path to input image")
    parser.add_argument("--model", choices=["i2v-14B"], default="i2v-14B")
    parser.add_argument(
        "--size",
        type=lambda x: tuple(map(int, x.split("x"))),
        default=(832, 480),
        help="Video size as WxH (default: 832x480)",
    )
    parser.add_argument("--frames", type=int, default=81)
    parser.add_argument(
        "--steps", type=int, default=40, help="Number of denoising steps"
    )
    parser.add_argument("--guidance", type=float, default=5.0)
    parser.add_argument("--shift", type=float, default=3.0)
    parser.add_argument("--seed", type=int)
    parser.add_argument(
        "--quantize",
        "-q",
        type=int,
        nargs="?",
        const=8,
        default=0,
        choices=[0, 4, 8],
        metavar="{4,8}",
        help="Quantize DiT weights (default: 8-bit when flag used without value)",
    )
    parser.add_argument(
        "--n-prompt",
        default="Text, watermarks, blurry image, JPEG artifacts",
    )
    parser.add_argument(
        "--teacache",
        type=float,
        default=0.0,
        help="TeaCache threshold for step skipping (0=off, 0.26=recommended for i2v)",
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        default=None,
        help="Path to custom DiT weights (.safetensors), e.g. distilled models",
    )
    parser.add_argument(
        "--sampler",
        choices=["unipc", "euler"],
        default="unipc",
        help="Sampler: unipc (default) or euler (for step-distilled models)",
    )
    parser.add_argument("--output", default="out.mp4")
    parser.add_argument("--preload-models", action="store_true")
    parser.add_argument(
        "--no-cache",
        action="store_true",
        help="Disable Metal buffer cache (mx.set_cache_limit(0)) to reduce swap pressure",
    )
    parser.add_argument("--verbose", "-v", action="store_true")
    args = parser.parse_args()

    if args.sampler == "euler":
        # Evenly spaced steps: e.g. 4 steps -> [1000, 750, 500, 250]
        n = args.steps
        denoising_step_list = [1000 * i // n for i in range(n, 0, -1)]
    else:
        denoising_step_list = None

    mx.set_default_device(mx.gpu)
    if args.no_cache:
        mx.set_cache_limit(0)

    if args.verbose:
        handler = logging.StreamHandler()
        handler.setFormatter(logging.Formatter("%(message)s"))
        logging.getLogger("wan").setLevel(logging.INFO)
        logging.getLogger("wan").addHandler(handler)

    # Load pipeline
    pipeline = WanPipeline(args.model, checkpoint=args.checkpoint)

    # Quantize DiT
    if args.quantize:
        nn.quantize(pipeline.flow, bits=args.quantize)
        print(f"Quantized DiT to {args.quantize}-bit")

    if args.preload_models:
        pipeline.ensure_models_are_loaded()

    # Generate latents (generator pattern)
    latents = pipeline.generate_latents(
        args.prompt,
        image_path=args.image,
        negative_prompt=args.n_prompt,
        size=args.size,
        frame_num=args.frames,
        num_steps=args.steps,
        guidance=args.guidance,
        shift=args.shift,
        seed=args.seed,
        teacache=args.teacache,
        verbose=args.verbose,
        denoising_step_list=denoising_step_list,
    )

    # 1. Conditioning
    conditioning = next(latents)
    mx.eval(conditioning)
    peak_mem_conditioning = mx.get_peak_memory() / 1024**3
    mx.reset_peak_memory()

    # Free T5 and CLIP memory
    del pipeline.t5
    if pipeline.clip is not None:
        del pipeline.clip
    mx.clear_cache()

    # 2. Denoising loop
    for x_t in tqdm(latents, total=args.steps):
        mx.eval(x_t)

    # Free DiT memory
    del pipeline.flow
    mx.clear_cache()
    peak_mem_generation = mx.get_peak_memory() / 1024**3
    mx.reset_peak_memory()

    # 3. VAE decode
    video = pipeline.decode(x_t)
    mx.eval(video)
    peak_mem_decoding = mx.get_peak_memory() / 1024**3

    # Save video
    save_video(video, args.output)

    if args.verbose:
        peak_mem_overall = max(
            peak_mem_conditioning, peak_mem_generation, peak_mem_decoding
        )
        print(f"Peak memory conditioning: {peak_mem_conditioning:.3f}GB")
        print(f"Peak memory generation: {peak_mem_generation:.3f}GB")
        print(f"Peak memory decoding: {peak_mem_decoding:.3f}GB")
        print(f"Peak memory overall: {peak_mem_overall:.3f}GB")
```

video/wan2.1/requirements.txt

Lines changed: 8 additions & 0 deletions

```
einops>=0.8.2  # for mlx compatible einops
huggingface_hub
mlx>=0.31.0  # for conv3d memory and speed fix
numpy
Pillow
tokenizers
torch  # for loading of huggingface weights
tqdm
```
