diff --git a/README.md b/README.md index 9b88bd349..dcdf98330 100644 --- a/README.md +++ b/README.md @@ -26,6 +26,10 @@ package for LLMs with MLX. - Image classification using [ResNets on CIFAR-10](cifar). - Convolutional variational autoencoder [(CVAE) on MNIST](cvae). +### Video Models + +- Text-to-video and image-to-video generation with [Wan2.1](video/wan2.1). + ### Audio Models - Speech recognition with [OpenAI's Whisper](whisper). diff --git a/video/wan2.1/README.md b/video/wan2.1/README.md new file mode 100644 index 000000000..68d4e1fba --- /dev/null +++ b/video/wan2.1/README.md @@ -0,0 +1,154 @@ +Wan2.1 +====== + +Wan2.1 text-to-video and image-to-video implementation in MLX. The model +weights are downloaded directly from the [Hugging Face +Hub](https://huggingface.co/Wan-AI). + +| Model | Task | HF Repo | RAM (unquantized), 81 frames | Single DiT step on M4 Max chip, 81 frames | +|-------|------|---------|-----------------|---| +| 1.3B | T2V | [Wan-AI/Wan2.1-T2V-1.3B](https://huggingface.co/Wan-AI/Wan2.1-T2V-1.3B) | ~10GB | ~90 s/it | +| 14B | T2V | [Wan-AI/Wan2.1-T2V-14B](https://huggingface.co/Wan-AI/Wan2.1-T2V-14B) | ~36GB | ~230 s/it | +| 14B | I2V | [Wan-AI/Wan2.1-I2V-14B-480P](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-480P) | ~39GB | ~250 s/it | + +| T2V 1.3B | T2V 14B | I2V 14B | +|---|---|---| +| ![WAN t2v 1.3B](static/out_t2v_1_3b.gif) |![WAN t2v 14B distilled](static/out_t2v_cats.gif) | ![WAN t2v 14B distilled](static/out_i2v_astronaut.gif) | +| Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage. | Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage. | An astronaut riding a horse | + +Installation +------------ + +Install the dependencies: +```shell +pip install -r requirements.txt +``` + +Saving videos requires [ffmpeg](https://ffmpeg.org/) on your PATH. 
+
+Usage
+-----
+
+### Text-to-Video
+
+Generate a video with the default 1.3B model:
+
+```shell
+python txt2video.py 'A cat playing piano' --output out.mp4
+```
+
+Use the 14B model with quantization:
+
+```shell
+python txt2video.py 'A cat playing piano' \
+    --model t2v-14B --quantize --output out_14B.mp4
+```
+
+Adjust resolution, frame count, and sampling parameters:
+
+```shell
+python txt2video.py 'Ocean waves crashing on a rocky shore at sunset' \
+    --size 832x480 --frames 81 --steps 50 --guidance 5.0 --seed 42 \
+    --output waves.mp4
+```
+
+For more parameters, use `python txt2video.py --help`.
+
+### Image-to-Video
+
+Generate a video from an input image:
+
+```shell
+python img2video.py 'Astronaut riding a horse' \
+    --image ./inputs/astronaut-on-a-horse.png --quantize --output out_i2v.mp4
+```
+
+Adjust resolution and sampling parameters:
+
+```shell
+python img2video.py 'Astronaut riding a horse' \
+    --image ./inputs/astronaut-on-a-horse.png --size 832x480 --frames 81 --steps 40 \
+    --guidance 5.0 --shift 3.0 --seed 42 --output out_i2v.mp4
+```
+
+For more parameters, use `python img2video.py --help`.
+
+### Quantization
+
+Pass `--quantize` (or `-q`) to the CLI:
+
+```shell
+python txt2video.py 'A cat playing piano' --quantize --output out_quantized.mp4
+```
+
+### Disabling the cache
+To get additional memory savings at the expense of a bit of speed, use the `--no-cache` argument. It will prevent MLX from utilizing the cache (sets `mx.set_cache_limit(0)` under the hood). See the [documentation](https://ml-explore.github.io/mlx/build/html/python/_autosummary/mlx.core.set_cache_limit.html) for more info.
+```shell
+python txt2video.py 'A cat playing piano' --output out.mp4 --no-cache
+```
+
+For the 1.3B model at 480p with 81 frames, a `--no-cache` run uses ~10GB of RAM, versus ~14GB otherwise.
+
+### Custom DiT Weights
+
+Use `--checkpoint` to load custom DiT weights (e.g. [step-distilled models](https://huggingface.co/lightx2v/Wan2.1-Distill-Models)). 
Pass `--sampler euler` to use Euler sampling for step-distilled models:

+For the text-to-video pipeline, you can try [this 4-step distilled model](https://huggingface.co/lightx2v/Wan2.1-Distill-Models/blob/main/wan2.1_t2v_14b_lightx2v_4step.safetensors)

+```shell
+wget https://huggingface.co/lightx2v/Wan2.1-Distill-Models/resolve/main/wan2.1_t2v_14b_lightx2v_4step.safetensors
+```
+
+```shell
+python txt2video.py 'A cat playing piano' \
+    --model t2v-14B --checkpoint ./wan2.1_t2v_14b_lightx2v_4step.safetensors \
+    --sampler euler --steps 4 --guidance 1.0 \
+    --quantize --output out_t2v_distilled.mp4
+```
+
+For the image-to-video pipeline, use [this 4-step distilled i2v model](https://huggingface.co/lightx2v/Wan2.1-Distill-Models/resolve/main/wan2.1_i2v_480p_lightx2v_4step.safetensors)
+
+```shell
+wget https://huggingface.co/lightx2v/Wan2.1-Distill-Models/resolve/main/wan2.1_i2v_480p_lightx2v_4step.safetensors
+```
+
+```shell
+python img2video.py 'Astronaut riding a horse' \
+    --image ./inputs/astronaut-on-a-horse.png --checkpoint ./wan2.1_i2v_480p_lightx2v_4step.safetensors \
+    --sampler euler --steps 4 --guidance 1.0 --shift 5.0 \
+    --quantize --output out_i2v_distilled.mp4
+```
+
+### Options
+
+- **Negative prompts**: `--n-prompt 'blurry, low quality, distorted'`
+- **Disable CFG**: `--guidance 1.0` skips the unconditional pass, roughly
+  halving compute per step.
+
+### TeaCache
+
+[TeaCache](https://arxiv.org/abs/2411.19108) skips redundant transformer computations when consecutive steps
+produce similar embeddings, eliminating 20-60% of forward passes. Note that the TeaCache parameters are calibrated for each resolution; consult the [LightX2V](https://github.com/ModelTC/LightX2V/tree/main/configs/caching) configs for advanced tweaking. 
Our defaults are located in [pipeline.py](./wan/pipeline.py#L20).

+```shell
+python txt2video.py 'A cat playing piano' --teacache 0.05 --output out.mp4 --verbose
+```
+
+Recommended thresholds (1.3B):
+
+| Threshold | Skip Rate | Quality |
+|-----------|-----------|---------|
+| `0.05` | ~34% | Almost lossless |
+| `0.1` | ~58% | Slightly corrupted |
+| `0.25` | ~76% | Visible quality loss |
+
+#### Result with --teacache for 1.3B model
+`Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.`
+|`--teacache 0.05`, 34% steps skipped (17/50) |`--teacache 0.1`, 58% steps skipped (29/50) |`--teacache 0.25`, 76% steps skipped (38/50) |
+|---|---|---|
+|![WAN t2v 1.3B teacache=0.05](static/out_t2v_1_3b_teacache_005.gif)|![WAN t2v 1.3B teacache=0.1](static/out_t2v_1_3b_teacache_01.gif)|![WAN t2v 1.3B teacache=0.25](static/out_t2v_1_3b_teacache_025.gif)|
+
+## References
+1. [Original WAN 2.1 implementation](https://github.com/Wan-Video/Wan2.1)
+2. [LightX2V](https://github.com/ModelTC/LightX2V)
diff --git a/video/wan2.1/img2video.py b/video/wan2.1/img2video.py
new file mode 100644
index 000000000..c2557389b
--- /dev/null
+++ b/video/wan2.1/img2video.py
@@ -0,0 +1,158 @@
+# Copyright © 2026 Apple Inc. 
+ +"""Generate videos from an image and text prompt using Wan2.1 I2V.""" + +import argparse +import logging + +import mlx.core as mx +import mlx.nn as nn +from tqdm import tqdm +from wan import WanPipeline +from wan.utils import save_video + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Generate videos from an image and text prompt using Wan2.1 I2V" + ) + parser.add_argument("prompt") + parser.add_argument("--image", required=True, help="Path to input image") + parser.add_argument("--model", choices=["i2v-14B"], default="i2v-14B") + parser.add_argument( + "--size", + type=lambda x: tuple(map(int, x.split("x"))), + default=(832, 480), + help="Video size as WxH (default: 832x480)", + ) + parser.add_argument("--frames", type=int, default=81) + parser.add_argument( + "--steps", type=int, default=40, help="Number of denoising steps" + ) + parser.add_argument("--guidance", type=float, default=5.0) + parser.add_argument("--shift", type=float, default=3.0) + parser.add_argument("--seed", type=int) + parser.add_argument( + "--quantize", + "-q", + type=int, + nargs="?", + const=8, + default=0, + choices=[0, 4, 8], + metavar="{4,8}", + help="Quantize DiT weights (default: 8-bit when flag used without value)", + ) + parser.add_argument( + "--n-prompt", + default="Text, watermarks, blurry image, JPEG artifacts", + ) + parser.add_argument( + "--teacache", + type=float, + default=0.0, + help="TeaCache threshold for step skipping (0=off, 0.26=recommended for i2v)", + ) + parser.add_argument( + "--checkpoint", + type=str, + default=None, + help="Path to custom DiT weights (.safetensors), e.g. 
distilled models", + ) + parser.add_argument( + "--sampler", + choices=["unipc", "euler"], + default="unipc", + help="Sampler: unipc (default) or euler (for step-distilled models)", + ) + parser.add_argument("--output", default="out.mp4") + parser.add_argument("--preload-models", action="store_true") + parser.add_argument( + "--no-cache", + action="store_true", + help="Disable Metal buffer cache (mx.set_cache_limit(0)) to reduce swap pressure", + ) + parser.add_argument("--verbose", "-v", action="store_true") + args = parser.parse_args() + + if args.sampler == "euler": + # Evenly spaced steps: e.g. 4 steps -> [1000, 750, 500, 250] + n = args.steps + denoising_step_list = [1000 * i // n for i in range(n, 0, -1)] + else: + denoising_step_list = None + + mx.set_default_device(mx.gpu) + if args.no_cache: + mx.set_cache_limit(0) + + if args.verbose: + handler = logging.StreamHandler() + handler.setFormatter(logging.Formatter("%(message)s")) + logging.getLogger("wan").setLevel(logging.INFO) + logging.getLogger("wan").addHandler(handler) + + # Load pipeline + pipeline = WanPipeline(args.model, checkpoint=args.checkpoint) + + # Quantize DiT + if args.quantize: + nn.quantize(pipeline.flow, bits=args.quantize) + print(f"Quantized DiT to {args.quantize}-bit") + + if args.preload_models: + pipeline.ensure_models_are_loaded() + + # Generate latents (generator pattern) + latents = pipeline.generate_latents( + args.prompt, + image_path=args.image, + negative_prompt=args.n_prompt, + size=args.size, + frame_num=args.frames, + num_steps=args.steps, + guidance=args.guidance, + shift=args.shift, + seed=args.seed, + teacache=args.teacache, + verbose=args.verbose, + denoising_step_list=denoising_step_list, + ) + + # 1. 
Conditioning + conditioning = next(latents) + mx.eval(conditioning) + peak_mem_conditioning = mx.get_peak_memory() / 1024**3 + mx.reset_peak_memory() + + # Free T5 and CLIP memory + del pipeline.t5 + if pipeline.clip is not None: + del pipeline.clip + mx.clear_cache() + + # 2. Denoising loop + for x_t in tqdm(latents, total=args.steps): + mx.eval(x_t) + + # Free DiT memory + del pipeline.flow + mx.clear_cache() + peak_mem_generation = mx.get_peak_memory() / 1024**3 + mx.reset_peak_memory() + + # 3. VAE decode + video = pipeline.decode(x_t) + mx.eval(video) + peak_mem_decoding = mx.get_peak_memory() / 1024**3 + + # Save video + save_video(video, args.output) + + if args.verbose: + peak_mem_overall = max( + peak_mem_conditioning, peak_mem_generation, peak_mem_decoding + ) + print(f"Peak memory conditioning: {peak_mem_conditioning:.3f}GB") + print(f"Peak memory generation: {peak_mem_generation:.3f}GB") + print(f"Peak memory decoding: {peak_mem_decoding:.3f}GB") + print(f"Peak memory overall: {peak_mem_overall:.3f}GB") diff --git a/video/wan2.1/inputs/astronaut-on-a-horse.png b/video/wan2.1/inputs/astronaut-on-a-horse.png new file mode 100644 index 000000000..521dbda0a Binary files /dev/null and b/video/wan2.1/inputs/astronaut-on-a-horse.png differ diff --git a/video/wan2.1/requirements.txt b/video/wan2.1/requirements.txt new file mode 100644 index 000000000..2dcf19fa9 --- /dev/null +++ b/video/wan2.1/requirements.txt @@ -0,0 +1,8 @@ +einops>=0.8.2 # for mlx compatible einops +huggingface_hub +mlx>=0.31.0 # for conv3d memory and speed fix +numpy +Pillow +tokenizers +torch # for loading of huggingface weights +tqdm diff --git a/video/wan2.1/static/out_i2v_astronaut.gif b/video/wan2.1/static/out_i2v_astronaut.gif new file mode 100644 index 000000000..4b0efd97b Binary files /dev/null and b/video/wan2.1/static/out_i2v_astronaut.gif differ diff --git a/video/wan2.1/static/out_t2v_1_3b.gif b/video/wan2.1/static/out_t2v_1_3b.gif new file mode 100644 index 000000000..41840133e 
Binary files /dev/null and b/video/wan2.1/static/out_t2v_1_3b.gif differ diff --git a/video/wan2.1/static/out_t2v_1_3b_teacache_005.gif b/video/wan2.1/static/out_t2v_1_3b_teacache_005.gif new file mode 100644 index 000000000..df75b777e Binary files /dev/null and b/video/wan2.1/static/out_t2v_1_3b_teacache_005.gif differ diff --git a/video/wan2.1/static/out_t2v_1_3b_teacache_01.gif b/video/wan2.1/static/out_t2v_1_3b_teacache_01.gif new file mode 100644 index 000000000..bcfde0a9f Binary files /dev/null and b/video/wan2.1/static/out_t2v_1_3b_teacache_01.gif differ diff --git a/video/wan2.1/static/out_t2v_1_3b_teacache_025.gif b/video/wan2.1/static/out_t2v_1_3b_teacache_025.gif new file mode 100644 index 000000000..a9c1d8849 Binary files /dev/null and b/video/wan2.1/static/out_t2v_1_3b_teacache_025.gif differ diff --git a/video/wan2.1/static/out_t2v_cats.gif b/video/wan2.1/static/out_t2v_cats.gif new file mode 100644 index 000000000..4587d8bc0 Binary files /dev/null and b/video/wan2.1/static/out_t2v_cats.gif differ diff --git a/video/wan2.1/txt2video.py b/video/wan2.1/txt2video.py new file mode 100644 index 000000000..66c8923ab --- /dev/null +++ b/video/wan2.1/txt2video.py @@ -0,0 +1,154 @@ +# Copyright © 2026 Apple Inc. 
+ +"""Generate videos from text using Wan2.1.""" + +import argparse +import logging + +import mlx.core as mx +import mlx.nn as nn +from tqdm import tqdm +from wan import WanPipeline +from wan.utils import save_video + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Generate videos from text using Wan2.1" + ) + parser.add_argument("prompt") + parser.add_argument("--model", choices=["t2v-1.3B", "t2v-14B"], default="t2v-1.3B") + parser.add_argument( + "--size", + type=lambda x: tuple(map(int, x.split("x"))), + default=(832, 480), + help="Video size as WxH (default: 832x480)", + ) + parser.add_argument("--frames", type=int, default=81) + parser.add_argument( + "--steps", type=int, default=50, help="Number of denoising steps" + ) + parser.add_argument("--guidance", type=float, default=5.0) + parser.add_argument("--shift", type=float, default=5.0) + parser.add_argument("--seed", type=int) + parser.add_argument( + "--quantize", + "-q", + type=int, + nargs="?", + const=8, + default=0, + choices=[0, 4, 8], + metavar="{4,8}", + help="Quantize DiT weights (default: 8-bit when flag used without value)", + ) + parser.add_argument( + "--n-prompt", + default="Text, watermarks, blurry image, JPEG artifacts", + ) + parser.add_argument( + "--teacache", + type=float, + default=0.0, + help="TeaCache threshold for step skipping (0=off, 0.05=recommended)", + ) + parser.add_argument( + "--checkpoint", + type=str, + default=None, + help="Path to custom DiT weights (.safetensors), e.g. 
distilled models", + ) + parser.add_argument( + "--sampler", + choices=["unipc", "euler"], + default="unipc", + help="Sampler: unipc (default) or euler (for step-distilled models)", + ) + parser.add_argument("--output", default="out.mp4") + parser.add_argument("--preload-models", action="store_true") + parser.add_argument( + "--no-cache", + action="store_true", + help="Disable Metal buffer cache (mx.set_cache_limit(0)) to reduce swap pressure", + ) + parser.add_argument("--verbose", "-v", action="store_true") + args = parser.parse_args() + + if args.sampler == "euler": + # Evenly spaced steps: e.g. 4 steps -> [1000, 750, 500, 250] + n = args.steps + denoising_step_list = [1000 * i // n for i in range(n, 0, -1)] + else: + denoising_step_list = None + + mx.set_default_device(mx.gpu) + if args.no_cache: + mx.set_cache_limit(0) + + if args.verbose: + handler = logging.StreamHandler() + handler.setFormatter(logging.Formatter("%(message)s")) + logging.getLogger("wan").setLevel(logging.INFO) + logging.getLogger("wan").addHandler(handler) + + # Load pipeline + pipeline = WanPipeline(args.model, checkpoint=args.checkpoint) + + # Quantize DiT + if args.quantize: + nn.quantize(pipeline.flow, bits=args.quantize) + print(f"Quantized DiT to {args.quantize}-bit") + + if args.preload_models: + pipeline.ensure_models_are_loaded() + + # Generate latents (generator pattern matching flux) + latents = pipeline.generate_latents( + args.prompt, + negative_prompt=args.n_prompt, + size=args.size, + frame_num=args.frames, + num_steps=args.steps, + guidance=args.guidance, + shift=args.shift, + seed=args.seed, + teacache=args.teacache, + verbose=args.verbose, + denoising_step_list=denoising_step_list, + ) + + # 1. Conditioning + conditioning = next(latents) + mx.eval(conditioning) + peak_mem_conditioning = mx.get_peak_memory() / 1024**3 + mx.reset_peak_memory() + + # Free T5 memory + del pipeline.t5 + mx.clear_cache() + + # 2. 
Denoising loop + for x_t in tqdm(latents, total=args.steps): + mx.eval(x_t) + + # Free DiT memory + del pipeline.flow + mx.clear_cache() + peak_mem_generation = mx.get_peak_memory() / 1024**3 + mx.reset_peak_memory() + + # 3. VAE decode + video = pipeline.decode(x_t) + mx.eval(video) + peak_mem_decoding = mx.get_peak_memory() / 1024**3 + + # Save video + save_video(video, args.output) + + if args.verbose: + peak_mem_overall = max( + peak_mem_conditioning, peak_mem_generation, peak_mem_decoding + ) + print(f"Peak memory conditioning: {peak_mem_conditioning:.3f}GB") + print(f"Peak memory generation: {peak_mem_generation:.3f}GB") + print(f"Peak memory decoding: {peak_mem_decoding:.3f}GB") + print(f"Peak memory overall: {peak_mem_overall:.3f}GB") diff --git a/video/wan2.1/wan/__init__.py b/video/wan2.1/wan/__init__.py new file mode 100644 index 000000000..4bb5ca783 --- /dev/null +++ b/video/wan2.1/wan/__init__.py @@ -0,0 +1,3 @@ +# Copyright © 2026 Apple Inc. + +from .pipeline import WanPipeline diff --git a/video/wan2.1/wan/clip.py b/video/wan2.1/wan/clip.py new file mode 100644 index 000000000..c7bf03b9e --- /dev/null +++ b/video/wan2.1/wan/clip.py @@ -0,0 +1,253 @@ +# Copyright © 2026 Apple Inc. + +""" +CLIP ViT-H/14 vision encoder for Wan2.1 I2V. + +Ported from the OpenCLIP XLM-RoBERTa CLIP model. Only the visual encoder is +used — text encoder, post_norm, and projection head are discarded. + +Architecture: ViT-H/14 (image_size=224, patch_size=14, dim=1280, heads=16, +layers=32, mlp_ratio=4, pool=token, activation=gelu, pre_norm=True). + +At inference the first 31 of 32 transformer blocks are used +(use_31_block=True). Block 31's weights are loaded but never evaluated. 
+""" + +import re + +import mlx.core as mx +import mlx.nn as nn + + +class CLIPAttentionBlock(nn.Module): + """Pre-norm transformer block with self-attention and MLP.""" + + def __init__(self, dim: int = 1280, num_heads: int = 16, mlp_ratio: int = 4): + super().__init__() + self.norm1 = nn.LayerNorm(dim) + self.self_attn = _CLIPSelfAttention(dim, num_heads) + self.norm2 = nn.LayerNorm(dim) + self.mlp = _CLIPMLP(dim, int(dim * mlp_ratio)) + + def __call__(self, x: mx.array) -> mx.array: + x = x + self.self_attn(self.norm1(x)) + x = x + self.mlp(self.norm2(x)) + return x + + +class _CLIPSelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int): + super().__init__() + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.q_proj = nn.Linear(dim, dim) + self.k_proj = nn.Linear(dim, dim) + self.v_proj = nn.Linear(dim, dim) + self.out_proj = nn.Linear(dim, dim) + + def __call__(self, x: mx.array) -> mx.array: + B, L, _ = x.shape + n, d = self.num_heads, self.head_dim + q = self.q_proj(x).reshape(B, L, n, d).transpose(0, 2, 1, 3) + k = self.k_proj(x).reshape(B, L, n, d).transpose(0, 2, 1, 3) + v = self.v_proj(x).reshape(B, L, n, d).transpose(0, 2, 1, 3) + x = mx.fast.scaled_dot_product_attention(q, k, v, scale=d**-0.5) + x = x.transpose(0, 2, 1, 3).reshape(B, L, n * d) + return self.out_proj(x) + + +class _CLIPMLP(nn.Module): + def __init__(self, dim: int, mid_dim: int): + super().__init__() + self.fc1 = nn.Linear(dim, mid_dim) + self.fc2 = nn.Linear(mid_dim, dim) + + def __call__(self, x: mx.array) -> mx.array: + return self.fc2(nn.gelu(self.fc1(x))) + + +class CLIPVisionEncoder(nn.Module): + """ViT-H/14 vision encoder returning patch + CLS token features.""" + + def __init__( + self, + image_size: int = 224, + patch_size: int = 14, + dim: int = 1280, + num_heads: int = 16, + num_layers: int = 32, + mlp_ratio: int = 4, + ): + super().__init__() + self.num_patches = (image_size // patch_size) ** 2 # 256 + self.dim = dim + self.num_layers = 
num_layers + + self.patch_embedding = nn.Conv2d( + 3, dim, kernel_size=patch_size, stride=patch_size, bias=False + ) + self.cls_embedding = mx.zeros((1, 1, dim)) + self.position_embedding = mx.zeros((1, self.num_patches + 1, dim)) + self.pre_norm = nn.LayerNorm(dim) + + # All 32 blocks loaded; only first 31 used in forward (see __call__). + # Block 31 weights are loaded but never evaluated. + for i in range(num_layers): + setattr(self, f"block_{i}", CLIPAttentionBlock(dim, num_heads, mlp_ratio)) + + def __call__(self, x: mx.array) -> mx.array: + """ + Args: + x: [B, 224, 224, 3] preprocessed image (channels-last). + + Returns: + [B, 257, 1280] CLS + patch token features. + """ + B = x.shape[0] + + # Patch embed: [B, H, W, 3] -> [B, H', W', dim] -> [B, num_patches, dim] + x = self.patch_embedding(x) + x = x.reshape(B, -1, self.dim) + + cls = mx.broadcast_to(self.cls_embedding, (B, 1, self.dim)) + x = mx.concatenate([cls, x], axis=1) + x = x + self.position_embedding + x = self.pre_norm(x) + + # Only first 31 of 32 blocks (matching reference use_31_block=True) + for i in range(self.num_layers - 1): + block = getattr(self, f"block_{i}") + x = block(x) + + return x + + @staticmethod + def sanitize(weights): + """Remap CLIP .pth checkpoint keys to MLX model format. + + Handles both standard OpenCLIP naming and Wan2.1 HF naming. + Extracts only visual.* keys. Splits fused QKV into q/k/v. + Discards post_norm, head, and all non-visual keys. 
+ """ + remapped = {} + for key, value in weights.items(): + if not key.startswith("visual."): + continue + + # Skip post_norm, head + if "post_norm" in key or "ln_post" in key or key == "visual.head": + continue + + # patch_embedding + if key in ("visual.conv1.weight", "visual.patch_embedding.weight"): + if value.ndim == 4: + value = mx.transpose(value, (0, 2, 3, 1)) + remapped["patch_embedding.weight"] = value + continue + + # CLS embedding + if key in ("visual.class_embedding", "visual.cls_embedding"): + remapped["cls_embedding"] = value.reshape(1, 1, -1) + continue + + # Position embedding + if key in ("visual.positional_embedding", "visual.pos_embedding"): + if value.ndim == 2: + value = value.reshape(1, value.shape[0], value.shape[1]) + remapped["position_embedding"] = value + continue + + # Pre-norm (both "visual.ln_pre.*" and "visual.pre_norm.*") + if key.startswith("visual.ln_pre.") or key.startswith("visual.pre_norm."): + param = key.split(".")[-1] + remapped[f"pre_norm.{param}"] = value + continue + + # Transformer blocks — handle both naming conventions: + # OpenCLIP: visual.transformer.resblocks.N.* + # Wan2.1 HF: visual.transformer.N.* + m = re.match(r"visual\.transformer\.(?:resblocks\.)?(\d+)\.(.*)", key) + if m: + block_idx = m.group(1) + rest = m.group(2) + + # Fused QKV: "attn.in_proj_*" or "attn.to_qkv.*" + if rest in ("attn.in_proj_weight", "attn.to_qkv.weight"): + dim = value.shape[0] // 3 + q, k, v = value[:dim], value[dim : 2 * dim], value[2 * dim :] + remapped[f"block_{block_idx}.self_attn.q_proj.weight"] = q + remapped[f"block_{block_idx}.self_attn.k_proj.weight"] = k + remapped[f"block_{block_idx}.self_attn.v_proj.weight"] = v + continue + + if rest in ("attn.in_proj_bias", "attn.to_qkv.bias"): + dim = value.shape[0] // 3 + q, k, v = value[:dim], value[dim : 2 * dim], value[2 * dim :] + remapped[f"block_{block_idx}.self_attn.q_proj.bias"] = q + remapped[f"block_{block_idx}.self_attn.k_proj.bias"] = k + 
remapped[f"block_{block_idx}.self_attn.v_proj.bias"] = v + continue + + # Out projection: "attn.out_proj.*" or "attn.proj.*" + for prefix in ("attn.out_proj.", "attn.proj."): + if rest.startswith(prefix): + param = rest.split(".")[-1] + remapped[f"block_{block_idx}.self_attn.out_proj.{param}"] = ( + value + ) + break + else: + # Norms: "ln_1.*" or "norm1.*", "ln_2.*" or "norm2.*" + for old, new in [ + ("ln_1.", "norm1."), + ("norm1.", "norm1."), + ("ln_2.", "norm2."), + ("norm2.", "norm2."), + ]: + if rest.startswith(old): + param = rest.split(".")[-1] + remapped[f"block_{block_idx}.{new}{param}"] = value + break + else: + # MLP: "mlp.c_fc.*" or "mlp.0.*", "mlp.c_proj.*" or "mlp.2.*" + for old, new in [ + ("mlp.c_fc.", "mlp.fc1."), + ("mlp.0.", "mlp.fc1."), + ("mlp.c_proj.", "mlp.fc2."), + ("mlp.2.", "mlp.fc2."), + ]: + if rest.startswith(old): + param = rest.split(".")[-1] + remapped[f"block_{block_idx}.{new}{param}"] = value + break + + return remapped + + +def preprocess_clip_image(image_path: str) -> mx.array: + """Load and preprocess an image for CLIP ViT-H/14. + + The reference CLIP visual() receives images in [-1, 1], then does + mul_(0.5).add_(0.5) to get [0, 1], then applies normalize. We replicate + this: load -> resize 224x224 -> [0,1] -> normalize. + + Returns: + [1, 224, 224, 3] float32 array (channels-last for MLX). 
+ """ + from PIL import Image + + mean = [0.48145466, 0.4578275, 0.40821073] + std = [0.26862954, 0.26130258, 0.27577711] + + img = Image.open(image_path).convert("RGB") + img = img.resize((224, 224), Image.BICUBIC) + + # Convert to float32 [0, 1] + import numpy as np + + arr = np.array(img).astype(np.float32) / 255.0 + + # Normalize per-channel + arr = (arr - np.array(mean, dtype=np.float32)) / np.array(std, dtype=np.float32) + + return mx.array(arr[np.newaxis]) # [1, 224, 224, 3] diff --git a/video/wan2.1/wan/layers.py b/video/wan2.1/wan/layers.py new file mode 100644 index 000000000..79c904fd7 --- /dev/null +++ b/video/wan2.1/wan/layers.py @@ -0,0 +1,251 @@ +# Copyright © 2026 Apple Inc. + +""" +Transformer layers for Wan2.1 DiT. + +Norms, attention, blocks, and output head. Uses bidirectional (non-causal) +attention with fused norm+modulate via mx.fast.layer_norm. +""" + +import math +from functools import partial +from typing import Tuple + +import mlx.core as mx +import mlx.nn as nn + +from .rope import rope_apply + + +# Compiled to fuse x + y * gate into a single Metal kernel (hot path). +@partial(mx.compile, shapeless=True) +def _residual_gate(x, y, gate): + return x + y * gate + + +class WanSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + eps: float = 1e-6, + ): + assert dim % num_heads == 0 + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + + self.qkv = nn.Linear(dim, dim * 3) + self.o = nn.Linear(dim, dim) + + self.norm_q = nn.RMSNorm(dim, eps=eps) + self.norm_k = nn.RMSNorm(dim, eps=eps) + + def _attend(self, x, grid_sizes): + """Compute self-attention. 
Returns attn output [B, n, L, d].""" + B, L, _ = x.shape + n, d = self.num_heads, self.head_dim + + qkv = self.qkv(x) + q, k, v = mx.split(qkv, 3, axis=-1) + + q = self.norm_q(q) + k = self.norm_k(k) + + q = q.reshape(B, L, n, d) + k = k.reshape(B, L, n, d) + v = v.reshape(B, L, n, d) + + q = rope_apply(q, grid_sizes, self.head_dim) + k = rope_apply(k, grid_sizes, self.head_dim) + + q = q.transpose(0, 2, 1, 3) + k = k.transpose(0, 2, 1, 3) + v = v.transpose(0, 2, 1, 3) + return mx.fast.scaled_dot_product_attention(q, k, v, scale=self.head_dim**-0.5) + + def __call__(self, x, grid_sizes): + B, L, C = x.shape + attn = self._attend(x, grid_sizes) + return self.o(attn.transpose(0, 2, 1, 3).reshape(B, L, C)) + + +class WanCrossAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + eps: float = 1e-6, + ): + assert dim % num_heads == 0 + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + + self.q = nn.Linear(dim, dim) + self.kv = nn.Linear(dim, dim * 2) + self.o = nn.Linear(dim, dim) + + self.norm_q = nn.RMSNorm(dim, eps=eps) + self.norm_k = nn.RMSNorm(dim, eps=eps) + + def _attend(self, x, context): + """Compute text cross-attention. Returns (q, attn_out) both [B, n, L, d].""" + B = x.shape[0] + L1, L2 = x.shape[1], context.shape[1] + n, d = self.num_heads, self.head_dim + + q = self.norm_q(self.q(x)) + kv = self.kv(context) + k, v = mx.split(kv, 2, axis=-1) + k = self.norm_k(k) + + q = q.reshape(B, L1, n, d).transpose(0, 2, 1, 3) + k = k.reshape(B, L2, n, d).transpose(0, 2, 1, 3) + v = v.reshape(B, L2, n, d).transpose(0, 2, 1, 3) + + out = mx.fast.scaled_dot_product_attention(q, k, v, scale=d**-0.5) + + return q, out + + def __call__(self, x, context): + _, attn = self._attend(x, context) + B, _, L1, _ = attn.shape + x = attn.transpose(0, 2, 1, 3).reshape(B, L1, self.dim) + return self.o(x) + + +# T5 text tokens in context; remaining tokens are CLIP image tokens (I2V only). 
+T5_CONTEXT_TOKEN_NUMBER = 512 + + +class WanI2VCrossAttention(WanCrossAttention): + """Cross-attention with separate image and text paths for I2V.""" + + def __init__(self, dim: int, num_heads: int, eps: float = 1e-6): + super().__init__(dim, num_heads, eps) + self.k_img = nn.Linear(dim, dim) + self.v_img = nn.Linear(dim, dim) + self.norm_k_img = nn.RMSNorm(dim, eps=eps) + + def __call__(self, x, context): + img_ctx_len = context.shape[1] - T5_CONTEXT_TOKEN_NUMBER + context_img = context[:, :img_ctx_len] + context_txt = context[:, img_ctx_len:] + + # Text attention + q, x_txt = self._attend(x, context_txt) + + # Image attention: reuses q from text path (q encodes the latent, not the context) + B, L1 = x.shape[:2] + n, d = self.num_heads, self.head_dim + L_img = context_img.shape[1] + ki = self.norm_k_img(self.k_img(context_img)) + vi = self.v_img(context_img) + ki = ki.reshape(B, L_img, n, d).transpose(0, 2, 1, 3) + vi = vi.reshape(B, L_img, n, d).transpose(0, 2, 1, 3) + x_img = mx.fast.scaled_dot_product_attention(q, ki, vi, scale=d**-0.5) + + x = (x_txt + x_img).transpose(0, 2, 1, 3).reshape(B, L1, self.dim) + return self.o(x) + + +_cross_attn_classes = { + "t2v": WanCrossAttention, + "i2v": WanI2VCrossAttention, +} + + +class WanAttentionBlock(nn.Module): + """ + Transformer block with self-attn, cross-attn, and FFN. + + Uses fused norm+modulate via mx.fast.layer_norm where the modulation + scale/shift are passed as weight/bias. Requires sanitize to bake 1+ + into modulation scale positions. 
+ """ + + def __init__( + self, + dim: int, + ffn_dim: int, + num_heads: int, + cross_attn_norm: bool = False, + eps: float = 1e-6, + cross_attn_type: str = "t2v", + ): + super().__init__() + self.dim = dim + self.eps = eps + + if cross_attn_norm: + self.norm3 = nn.LayerNorm(dim, eps=eps) + else: + self.norm3 = None + + self.self_attn = WanSelfAttention(dim, num_heads, eps) + self.cross_attn = _cross_attn_classes[cross_attn_type](dim, num_heads, eps) + + self.ffn = nn.Sequential( + nn.Linear(dim, ffn_dim), + nn.GELU(approx="tanh"), + nn.Linear(ffn_dim, dim), + ) + + # Modulation: [shift, scale, gate] x 2 for self-attn (indices 0-2) and FFN (indices 3-5) + self.modulation = mx.zeros((1, 6, dim)) + + def __call__( + self, + x: mx.array, + e: mx.array, + grid_sizes: list, + context: mx.array, + ) -> mx.array: + e = self.modulation + e + + # Self-attention: fused LayerNorm where e[:,1]=scale (weight), e[:,0]=shift (bias), e[:,2]=gate + y = self.self_attn( + mx.fast.layer_norm(x, e[0, 1], e[0, 0], self.eps), + grid_sizes, + ) + x = _residual_gate(x, y, e[:, 2]) + + # Cross-attention + if self.norm3 is not None: + x_normed = self.norm3(x) + else: + x_normed = x + x = x + self.cross_attn(x_normed, context) + + # FFN: fused LayerNorm where e[:,4]=scale, e[:,3]=shift, e[:,5]=gate + y = self.ffn(mx.fast.layer_norm(x, e[0, 4], e[0, 3], self.eps)) + x = _residual_gate(x, y, e[:, 5]) + + return x + + +class Head(nn.Module): + """Output head with fused norm+modulate and nn.Linear.""" + + def __init__( + self, + dim: int, + out_dim: int, + patch_size: Tuple[int, int, int], + eps: float = 1e-6, + ): + super().__init__() + self.dim = dim + self.eps = eps + out_features = math.prod(patch_size) * out_dim + self.linear = nn.Linear(dim, out_features) + # Modulation: [shift, scale] for output head norm + self.modulation = mx.zeros((1, 2, dim)) + + def __call__(self, x: mx.array, e: mx.array) -> mx.array: + e = self.modulation + e[:, None, :] + x = mx.fast.layer_norm(x, e[0, 1], e[0, 0], 
self.eps) + return self.linear(x) diff --git a/video/wan2.1/wan/model.py b/video/wan2.1/wan/model.py new file mode 100644 index 000000000..5285ad967 --- /dev/null +++ b/video/wan2.1/wan/model.py @@ -0,0 +1,298 @@ +# Copyright © 2026 Apple Inc. + +""" +Wan2.1 bidirectional DiT (Diffusion Transformer) for video generation. + +Supports 1.3B and 14B model sizes with text-to-video (t2v) and +image-to-video (i2v) modes. Uses bidirectional attention with +nn.Sequential embeddings and list-based block storage. +""" + +import math +import re +from functools import partial +from typing import Dict, Optional, Tuple + +import mlx.core as mx +import mlx.nn as nn +from einops import rearrange + +from .layers import Head, WanAttentionBlock + + +# shapeless=True: avoids recompilation across varying input shapes. +@partial(mx.compile, shapeless=True) +def sinusoidal_embedding_1d(dim: int, position: mx.array) -> mx.array: + assert dim % 2 == 0 + half = dim // 2 + dtype = position.dtype + position = position.astype(mx.float32) + sinusoid = ( + position[:, None] + * mx.exp(-math.log(10000) * mx.arange(half, dtype=mx.float32) / half)[None, :] + ) + return mx.concatenate([mx.cos(sinusoid), mx.sin(sinusoid)], axis=1).astype(dtype) + + +class WanModel(nn.Module): + def __init__( + self, + model_type: str = "t2v", + patch_size: Tuple[int, int, int] = (1, 2, 2), + text_len: int = 512, + in_dim: int = 16, + dim: int = 2048, + ffn_dim: int = 8192, + freq_dim: int = 256, + text_dim: int = 4096, + out_dim: int = 16, + num_heads: int = 16, + num_layers: int = 32, + cross_attn_norm: bool = True, + eps: float = 1e-6, + ): + super().__init__() + self.patch_size = patch_size + self.dim = dim + self.freq_dim = freq_dim + + self.patch_embedding = nn.Conv3d( + in_dim, dim, kernel_size=patch_size, stride=patch_size, bias=True + ) + + self.text_embedding = nn.Sequential( + nn.Linear(text_dim, dim), nn.GELU(approx="tanh"), nn.Linear(dim, dim) + ) + self.time_embedding = nn.Sequential( + 
nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim) + ) + self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, 6 * dim)) + + # Image embedding MLP for I2V: LayerNorm -> Linear -> GELU -> Linear -> LayerNorm + if model_type == "i2v": + clip_dim = 1280 + self.img_emb_norm1 = nn.LayerNorm(clip_dim) + self.img_emb_linear1 = nn.Linear(clip_dim, clip_dim) + self.img_emb_linear2 = nn.Linear(clip_dim, dim) + self.img_emb_norm2 = nn.LayerNorm(dim) + + # Transformer blocks as list + self.blocks = [ + WanAttentionBlock( + dim, + ffn_dim, + num_heads, + cross_attn_norm, + eps, + cross_attn_type=model_type, + ) + for _ in range(num_layers) + ] + + # Output head + self.head = Head(dim, out_dim, patch_size, eps) + + def _embed_image(self, clip_fea: mx.array) -> mx.array: + """Project CLIP features through img_emb MLP.""" + x = self.img_emb_norm1(clip_fea) + x = self.img_emb_linear1(x) + x = nn.gelu(x) + x = self.img_emb_linear2(x) + x = self.img_emb_norm2(x) + return x + + def compute_time_embedding(self, t: mx.array): + """Compute time embeddings for TeaCache. Returns (t_emb, e0). + t_emb: [1, dim] (pre-projection, used by head) + e0: [1, 6*dim] (projected, used for block modulation)""" + e = sinusoidal_embedding_1d(self.freq_dim, t) + t_emb = self.time_embedding(e) + e0 = self.time_projection(t_emb) + return t_emb, e0 + + def __call__( + self, + x: mx.array, + t: mx.array, + context: mx.array, + block_residual: Optional[mx.array] = None, + precomputed_time: Optional[Tuple[mx.array, mx.array]] = None, + clip_fea: Optional[mx.array] = None, + first_frame: Optional[mx.array] = None, + ) -> Tuple[mx.array, mx.array]: + """ + Forward pass for t2v and i2v. 
+ + Args: + x: Input latent [F, H, W, C_in] (channels-last) + t: Timestep [1] + context: Text embedding [L, C_text] + block_residual: Precomputed block residual for TeaCache skip + precomputed_time: (t_emb, e0) tuple for TeaCache + clip_fea: CLIP image features [1, 257, 1280] (I2V only) + first_frame: Image conditioning [F, H, W, C_cond] (I2V only). + Concatenated channel-wise with x before patchify (in_dim=36). + + Returns: + (output, block_residual): output latent [F, H, W, C_out] and + block residual for TeaCache caching (None-equivalent zeros when + using cached residual). + """ + # Channel-concat image conditioning before patchify (I2V) + if first_frame is not None: + x = mx.concatenate([x, first_frame], axis=-1) + + # Patchify: [F, H, W, C] -> [1, F, H, W, C] -> conv3d -> [1, Fp, Hp, Wp, dim] + x = self.patch_embedding(x[None]) + _, Fp, Hp, Wp, _ = x.shape + grid_sizes = [[Fp, Hp, Wp]] + x = x.reshape(1, Fp * Hp * Wp, self.dim) + + # Embed context: [L, C_text] -> [1, text_len, dim] + context = self.text_embedding(context[None]) + + # Prepend projected CLIP features to context (I2V) + if clip_fea is not None: + clip_proj = self._embed_image(clip_fea) + context = mx.concatenate([clip_proj, context], axis=1) + + # Time embedding + if precomputed_time is not None: + t_emb, e = precomputed_time[0], precomputed_time[1] + else: + e = sinusoidal_embedding_1d(self.freq_dim, t) + t_emb = self.time_embedding(e) + e = self.time_projection(t_emb) + e = e.reshape(1, 6, self.dim) + + # Transformer blocks + if block_residual is not None: + x = x + block_residual + new_residual = block_residual # pass through (caller won't cache this) + else: + x_in = x + for block in self.blocks: + x = block(x, e, grid_sizes, context) + new_residual = x - x_in + + # Output head + x = self.head(x, t_emb) + + # Unpatchify: [1, seq_len, patch_features] -> [F, H, W, C] + pt, ph, pw = self.patch_size + output = rearrange( + x[0], + "(Fp Hp Wp) (pt ph pw c) -> (Fp pt) (Hp ph) (Wp pw) c", + Fp=Fp, 
+ Hp=Hp, + Wp=Wp, + pt=pt, + ph=ph, + pw=pw, + ) + return output, new_residual + + @staticmethod + def sanitize(weights: Dict[str, mx.array]) -> Dict[str, mx.array]: + """Remap PyTorch checkpoint keys to MLX model format.""" + remapped = {} + for key, value in weights.items(): + new_key = key + + # Skip fp8 scale metadata from LightX2V quantized checkpoints + if "weight_scale" in new_key: + continue + + # Remove model. prefix + if new_key.startswith("model."): + new_key = new_key[6:] + + # PyTorch Conv3d [O,I,kT,kH,kW] -> MLX Conv3d [O,kT,kH,kW,I] + if ( + "patch_embedding" in new_key + and "weight" in new_key + and len(value.shape) == 5 + ): + value = mx.transpose(value, (0, 2, 3, 4, 1)) + + # PyTorch nn.Sequential uses flat keys ("ffn.0."), MLX nests under ".layers." ("ffn.layers.0.") + new_key = new_key.replace("ffn.0.", "ffn.layers.0.") + new_key = new_key.replace("ffn.2.", "ffn.layers.2.") + + new_key = new_key.replace("text_embedding.0.", "text_embedding.layers.0.") + new_key = new_key.replace("text_embedding.2.", "text_embedding.layers.2.") + + new_key = new_key.replace("time_embedding.0.", "time_embedding.layers.0.") + new_key = new_key.replace("time_embedding.2.", "time_embedding.layers.2.") + + new_key = new_key.replace("time_projection.1.", "time_projection.layers.1.") + + # head.head -> head.linear + new_key = new_key.replace("head.head.", "head.linear.") + + # img_emb.proj.N -> img_emb_* (I2V MLPProj) + new_key = re.sub(r"img_emb\.proj\.0\.(\w+)", r"img_emb_norm1.\1", new_key) + new_key = re.sub(r"img_emb\.proj\.1\.(\w+)", r"img_emb_linear1.\1", new_key) + new_key = re.sub(r"img_emb\.proj\.3\.(\w+)", r"img_emb_linear2.\1", new_key) + new_key = re.sub(r"img_emb\.proj\.4\.(\w+)", r"img_emb_norm2.\1", new_key) + + remapped[new_key] = value + + # Merge separate Q/K/V into QKV for self-attention, + # and K/V into KV for cross-attention + remapped = WanModel._merge_qkv_weights(remapped) + + # Modulation vectors are [shift, scale, gate, ...]. 
The DiT block applies + # them as x * (1 + scale) + shift, but we fuse the "1 +" into the stored + # scale weights here so the forward pass is just x * scale + shift. + for key in list(remapped.keys()): + if key.endswith(".modulation"): + v = remapped[key] + if v.shape[1] == 6: # block modulation [1, 6, dim] + # Add 1 to scale positions (1 and 4) + remapped[key] = v + mx.array([0, 1, 0, 0, 1, 0])[:, None] + elif v.shape[1] == 2: # head modulation [1, 2, dim] + remapped[key] = v + mx.array([0, 1])[:, None] + + return remapped + + @staticmethod + def _merge_qkv_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]: + """Merge separate q/k/v weights into qkv (self-attn) and kv (cross-attn).""" + merged = {} + consumed = set() + + for key in weights: + # Self-attention: merge q, k, v -> qkv + m = re.match(r"(blocks\.\d+\.self_attn)\.(q)\.(weight|bias)$", key) + if m: + prefix, _, param = m.groups() + q_key = f"{prefix}.q.{param}" + k_key = f"{prefix}.k.{param}" + v_key = f"{prefix}.v.{param}" + if q_key in weights and k_key in weights and v_key in weights: + merged[f"{prefix}.qkv.{param}"] = mx.concatenate( + [weights[q_key], weights[k_key], weights[v_key]], axis=0 + ) + consumed.update([q_key, k_key, v_key]) + continue + + # Cross-attention: merge k, v -> kv (q stays separate) + m = re.match(r"(blocks\.\d+\.cross_attn)\.(k)\.(weight|bias)$", key) + if m: + prefix, _, param = m.groups() + k_key = f"{prefix}.k.{param}" + v_key = f"{prefix}.v.{param}" + if k_key in weights and v_key in weights: + merged[f"{prefix}.kv.{param}"] = mx.concatenate( + [weights[k_key], weights[v_key]], axis=0 + ) + consumed.update([k_key, v_key]) + continue + + # Copy all non-consumed keys + for key, value in weights.items(): + if key not in consumed: + merged[key] = value + + return merged diff --git a/video/wan2.1/wan/pipeline.py b/video/wan2.1/wan/pipeline.py new file mode 100644 index 000000000..dd83c7746 --- /dev/null +++ b/video/wan2.1/wan/pipeline.py @@ -0,0 +1,404 @@ +# 
Copyright © 2026 Apple Inc. + +""" +Wan2.1 text-to-video and image-to-video pipeline. +""" + +import logging +from typing import Optional, Tuple + +import mlx.core as mx +import numpy as np + +logger = logging.getLogger(__name__) + +from .sampler import FlowEulerDiscreteScheduler, FlowUniPCMultistepScheduler +from .utils import configs, load_clip, load_dit, load_t5, load_t5_tokenizer, load_vae + +# Polynomial coefficients for TeaCache distance rescaling (calibrated per model). +# Each entry has keys: coeffs, ret_steps, use_e0 +_tea_coeffs = { # from https://github.com/ModelTC/LightX2V/blob/main/configs/caching/teacache/wan_t2v_1_3b_tea_480p.json + "t2v-1.3B": { + "coeffs": [ + -5.21862437e04, + 9.23041404e03, + -5.28275948e02, + 1.36987616e01, + -4.99875664e-02, + ], + "ret_steps": 5, + "use_e0": True, + }, + "t2v-14B": { # from https://github.com/ModelTC/LightX2V/blob/main/configs/caching/custom/wan_t2v_custom_14b.json + "coeffs": [ + -5784.54975374, + 5449.50911966, + -1811.16591783, + 256.27178429, + -13.02252404, + ], + "ret_steps": 1, + "use_e0": False, + }, + "i2v-14B": { # from https://github.com/ModelTC/LightX2V/blob/main/configs/caching/teacache/wan_i2v_tea_480p.json + "coeffs": [ + 2.57151496e05, + -3.54229917e04, + 1.40286849e03, + -1.35890334e01, + 1.32517977e-01, + ], + "ret_steps": 5, + "use_e0": True, + }, +} + + +class WanPipeline: + def __init__( + self, + name: str = "t2v-1.3B", + dtype: mx.Dtype = mx.bfloat16, + checkpoint: Optional[str] = None, + ): + self.dtype = dtype + self.name = name + self.vae_stride = (4, 8, 8) + self.z_dim = 16 + self._null_context = None + + self.flow = load_dit(name, checkpoint=checkpoint) + self.vae = load_vae(name) + self.t5 = load_t5(name) + self.t5_tokenizer = load_t5_tokenizer(name) + self.clip = load_clip(name) if configs[name].repo_clip else None + self.sampler = FlowUniPCMultistepScheduler() + + def ensure_models_are_loaded(self): + params = [ + self.flow.parameters(), + self.vae.parameters(), + 
self.t5.parameters(), + ] + if self.clip is not None: + params.append(self.clip.parameters()) + mx.eval(*params) + + def _encode_text(self, text: str) -> mx.array: + """Encode text prompt with T5. Returns [512, 4096].""" + tokens = self.t5_tokenizer(text) + ids = tokens["input_ids"] + mask = tokens["attention_mask"] + embeddings = self.t5(ids, mask=mask) + seq_len = int(mask.sum().item()) + context = embeddings[0, :seq_len, :] + if seq_len < 512: + padding = mx.zeros((512 - seq_len, context.shape[-1])) + context = mx.concatenate([context, padding], axis=0) + return context + + def _encode_null(self) -> mx.array: + """Return cached empty-string T5 embedding for CFG.""" + if self._null_context is None: + self._null_context = self._encode_text("") + return self._null_context + + def _encode_clip(self, image_path: str) -> mx.array: + """Encode image with CLIP. Returns [1, 257, 1280].""" + from .clip import preprocess_clip_image + + img = preprocess_clip_image(image_path) + return self.clip(img).astype(self.dtype) + + def _prepare_image_conditioning( + self, image_path: str, size: Tuple[int, int], frame_num: int + ) -> mx.array: + """Prepare VAE-encoded first frame + temporal mask. 
+ + Returns: + y: [T', H', W', 20] conditioning tensor (channels-last) + """ + from PIL import Image + + W, H = size + T_latent = (frame_num - 1) // self.vae_stride[0] + 1 + H_latent = H // self.vae_stride[1] + W_latent = W // self.vae_stride[2] + + # Load image, resize (short side) + center crop to target resolution + img = Image.open(image_path).convert("RGB") + iw, ih = img.size + scale = max(W / iw, H / ih) + rw, rh = round(iw * scale), round(ih * scale) + img = img.resize((rw, rh), Image.BICUBIC) + left = (rw - W) // 2 + top = (rh - H) // 2 + img = img.crop((left, top, left + W, top + H)) + + # Normalize to [-1, 1] + img_arr = np.array(img).astype(np.float32) / 255.0 + img_arr = (img_arr - 0.5) / 0.5 + img_tensor = mx.array(img_arr) # [H, W, 3] + + # Build video: first frame = image, rest = zeros -> [F, H, W, 3] + zeros = mx.zeros((frame_num - 1, H, W, 3)) + video = mx.concatenate([img_tensor[None], zeros], axis=0) + + # VAE encode -> [T', H', W', 16] + vae_latent = self.vae.encode(video) + + # Build temporal mask -> [T', H', W', 4] + msk_first = mx.ones((1, H_latent, W_latent, 4)) + msk_rest = mx.zeros((T_latent - 1, H_latent, W_latent, 4)) + msk = mx.concatenate([msk_first, msk_rest], axis=0) + + # Concat: [T', H', W', 4+16] = [T', H', W', 20] + y = mx.concatenate([msk, vae_latent], axis=-1) + return y.astype(self.dtype) + + def _precompute_teacache(self, sampler, num_steps, teacache): + """Precompute time embeddings and TeaCache skip schedule. + + Returns: + (all_t_emb, all_e0, skip_mask): Lists of precomputed embeddings + and a boolean list where True means skip (use cached residual). 
+ """ + tea_cfg = _tea_coeffs[self.name] + coeffs = mx.array(tea_cfg["coeffs"], dtype=mx.float64) + ret_steps = tea_cfg["ret_steps"] + use_e0 = tea_cfg["use_e0"] + cutoff_steps = num_steps if use_e0 else num_steps - 1 + + # Precompute all time embeddings (float32, lazy) + all_t_emb = [] + all_e0 = [] + for t in sampler.timesteps: + t_val = t.reshape(1).astype(mx.float32) + t_emb, e0 = self.flow.compute_time_embedding(t_val) + all_t_emb.append(t_emb) + all_e0.append(e0) + + # Vectorized distance computation (lazy) + embs = mx.stack(all_e0 if use_e0 else all_t_emb) # [N, 1, D] + raw_dists = mx.abs(embs[1:] - embs[:-1]).mean(axis=(1, 2)) / ( + mx.abs(embs[:-1]).mean(axis=(1, 2)) + 1e-8 + ) + + # Polynomial rescaling in float64 on CPU + with mx.stream(mx.cpu): + dists_f64 = raw_dists.astype(mx.float64) + rescaled = coeffs[0] + for c in coeffs[1:]: + rescaled = rescaled * dists_f64 + c + rescaled = mx.abs(rescaled).astype(mx.float32) + + # Single eval materializes embeddings (GPU) + rescaled distances (CPU) + mx.eval(rescaled, *all_t_emb, *all_e0) + + # Simulate accumulation to build skip schedule + skip_mask = [] + accum = mx.array(0.0) + for step_idx in range(num_steps): + must_compute = ( + step_idx < ret_steps or step_idx >= cutoff_steps or step_idx == 0 + ) + if not must_compute: + accum += rescaled[step_idx - 1] + + should_skip = not must_compute and accum < teacache + skip_mask.append(should_skip) + if not should_skip: + accum = mx.array(0.0) + + return all_t_emb, all_e0, skip_mask + + def generate_latents( + self, + text: str, + image_path: Optional[str] = None, + negative_prompt: str = "", + size: Tuple[int, int] = (832, 480), + frame_num: int = 81, + num_steps: int = 50, + guidance: float = 5.0, + shift: float = 5.0, + seed: Optional[int] = None, + teacache: float = 0.0, + verbose: bool = False, + denoising_step_list=None, + ): + """ + Generator yielding latents at each denoising step. 
+ + First yield: conditioning tuple (for mx.eval by caller) + Subsequent yields: latent at each denoising step + + Args: + image_path: Path to input image (I2V mode). None for T2V. + denoising_step_list: If provided, use Euler scheduler for + step-distilled models (e.g. [1000, 750, 500, 250]). + """ + if denoising_step_list is not None and teacache > 0: + logger.warning( + "TeaCache is not calibrated for distilled models; disabling." + ) + teacache = 0.0 + + if seed is not None: + mx.random.seed(seed) + + W, H = size + target_shape = ( + (frame_num - 1) // self.vae_stride[0] + 1, + H // self.vae_stride[1], + W // self.vae_stride[2], + self.z_dim, + ) + + # Encode text + context = self._encode_text(text) + if negative_prompt: + context_null = self._encode_text(negative_prompt) + else: + context_null = self._encode_null() + + # Image conditioning (I2V only) + clip_features = None + first_frame = None + if image_path is not None and self.clip is not None: + clip_features = self._encode_clip(image_path) + first_frame = self._prepare_image_conditioning(image_path, size, frame_num) + + # Initial noise + x_T = mx.random.normal(target_shape).astype(self.dtype) + + # Yield conditioning for controlled evaluation + yield (x_T, context, context_null, clip_features, first_frame) + + # Denoising loop — choose sampler + if denoising_step_list is not None: + sampler = FlowEulerDiscreteScheduler() + sampler.set_timesteps(denoising_step_list, shift=shift) + num_steps = len(denoising_step_list) + else: + sampler = self.sampler + sampler.set_timesteps(num_steps, shift=shift) + + # TeaCache state + use_teacache = teacache > 0 + if use_teacache: + # Must run before mx.compile(self.flow.__call__) below, since + # compute_time_embedding uses self.flow.state and mx.eval here + # materializes those parameters. 
+ all_t_emb, all_e0, skip_mask = self._precompute_teacache( + sampler, num_steps, teacache + ) + prev_residual_cond = None + prev_residual_uncond = None + + if verbose: + n_skip = sum(skip_mask) + logger.info( + f"TeaCache: will skip {n_skip}/{num_steps} steps " + f"({100 * n_skip / num_steps:.0f}%)" + ) + + flow = mx.compile(self.flow.__call__, inputs=[self.flow.state]) + + x_t = x_T + for step_idx, t in enumerate(sampler.timesteps): + t_val = t.reshape(1).astype(mx.float32) + + if use_teacache: + precomputed = (all_t_emb[step_idx], all_e0[step_idx]) + + if skip_mask[step_idx]: + noise_cond, _ = flow( + x_t, + t=t_val, + context=context, + clip_fea=clip_features, + first_frame=first_frame, + block_residual=prev_residual_cond, + precomputed_time=precomputed, + ) + if verbose: + logger.info(f"Step {step_idx}/{num_steps}: skip") + else: + noise_cond, prev_residual_cond = flow( + x_t, + t=t_val, + context=context, + clip_fea=clip_features, + first_frame=first_frame, + precomputed_time=precomputed, + ) + mx.eval( + prev_residual_cond + ) # Materialize residual now so it persists for cached (skip) steps. 
+ if verbose: + logger.info(f"Step {step_idx}/{num_steps}: compute") + + if guidance > 1.0: + if skip_mask[step_idx]: + noise_uncond, _ = flow( + x_t, + t=t_val, + context=context_null, + clip_fea=clip_features, + first_frame=first_frame, + block_residual=prev_residual_uncond, + precomputed_time=precomputed, + ) + else: + noise_uncond, prev_residual_uncond = flow( + x_t, + t=t_val, + context=context_null, + clip_fea=clip_features, + first_frame=first_frame, + precomputed_time=precomputed, + ) + mx.eval(prev_residual_uncond) + noise_pred = noise_uncond + guidance * (noise_cond - noise_uncond) + else: + noise_pred = noise_cond + else: + # Standard path + noise_cond, _ = flow( + x_t, + t=t_val, + context=context, + clip_fea=clip_features, + first_frame=first_frame, + ) + + if guidance > 1.0: + noise_uncond, _ = flow( + x_t, + t=t_val, + context=context_null, + clip_fea=clip_features, + first_frame=first_frame, + ) + noise_pred = noise_uncond + guidance * (noise_cond - noise_uncond) + else: + noise_pred = noise_cond + + # async_eval starts GPU work on x_t and returns immediately, + # so the caller's mx.eval blocks less (pipeline overlap). + x_t = sampler.step(noise_pred, t, x_t) + mx.async_eval(x_t) + yield x_t + + def decode(self, latents: mx.array) -> mx.array: + """ + Decode latents to video frames. + + Args: + latents: [F, H, W, C] latent tensor (channels-last) + + Returns: + [F, H, W, C] video tensor in [-1, 1] (channels-last) + """ + return self.vae.decode(latents) diff --git a/video/wan2.1/wan/rope.py b/video/wan2.1/wan/rope.py new file mode 100644 index 000000000..1b0818830 --- /dev/null +++ b/video/wan2.1/wan/rope.py @@ -0,0 +1,85 @@ +# Copyright © 2026 Apple Inc. + +""" +Rotary Position Embedding (RoPE) for 3D video transformers. + +Implements 3-axis RoPE for temporal, height, and width dimensions. +Uses mx.fast.rope for optimized Metal kernel. 
+""" + +from typing import Tuple + +import mlx.core as mx +from einops import rearrange + + +def get_rope_dimensions(head_dim: int) -> Tuple[int, int, int]: + """ + Get the dimension split for 3D RoPE. + + - Frame: d - 4*(d//6) + - Height: 2*(d//6) + - Width: 2*(d//6) + """ + d = head_dim + frame_dim = d - 4 * (d // 6) + height_dim = 2 * (d // 6) + width_dim = 2 * (d // 6) + return frame_dim, height_dim, width_dim + + +@mx.compile +def _rope_3d(x, f, h, w, frame_dim, height_dim, width_dim, theta): + B = x.shape[0] + + x_frame = x[..., :frame_dim] + x_height = x[..., frame_dim : frame_dim + height_dim] + x_width = x[..., frame_dim + height_dim :] + + # Frame RoPE + x_frame = rearrange(x_frame, "B (f hw) n d -> (B hw) n f d", f=f) + x_frame = mx.fast.rope( + x_frame, dims=frame_dim, traditional=True, base=theta, scale=1.0, offset=0 + ) + x_frame = rearrange(x_frame, "(B hw) n f d -> B (f hw) n d", B=B, f=f) + + # Height RoPE + x_height = rearrange(x_height, "B (f h w) n d -> (B f w) n h d", f=f, h=h, w=w) + x_height = mx.fast.rope( + x_height, dims=height_dim, traditional=True, base=theta, scale=1.0, offset=0 + ) + x_height = rearrange(x_height, "(B f w) n h d -> B (f h w) n d", B=B, f=f, w=w) + + # Width RoPE + x_width = rearrange(x_width, "B (f h w) n d -> (B f h) n w d", f=f, h=h, w=w) + x_width = mx.fast.rope( + x_width, dims=width_dim, traditional=True, base=theta, scale=1.0, offset=0 + ) + x_width = rearrange(x_width, "(B f h) n w d -> B (f h w) n d", B=B, f=f, h=h) + + return mx.concatenate([x_frame, x_height, x_width], axis=-1) + + +def rope_apply( + x: mx.array, + grid_sizes: list, + head_dim: int, + theta: float = 10000.0, +) -> mx.array: + """ + Apply 3D RoPE using mx.fast.rope with reshapes. 
+ + Args: + x: Tensor of shape [B, L, H, D] + grid_sizes: List of [frames, height, width] per batch element + head_dim: Dimension per attention head + theta: RoPE base frequency + + Returns: + Rotated tensor with same shape as x + """ + f, h, w = grid_sizes[0] + + frame_dim, height_dim, width_dim = get_rope_dimensions(head_dim) + + return _rope_3d(x, f, h, w, frame_dim, height_dim, width_dim, theta) diff --git a/video/wan2.1/wan/sampler.py b/video/wan2.1/wan/sampler.py new file mode 100644 index 000000000..e13944d8f --- /dev/null +++ b/video/wan2.1/wan/sampler.py @@ -0,0 +1,415 @@ +# Copyright © 2026 Apple Inc. + +# Ported from https://github.com/huggingface/diffusers/blob/v0.31.0/src/diffusers/schedulers/scheduling_unipc_multistep.py +# Converted for flow matching. + +""" +FlowUniPCMultistepScheduler for Wan2.1 denoising. + +UniPC multistep solver adapted for flow matching prediction. +""" + +from functools import partial +from typing import List, Optional + +import mlx.core as mx + + +def _lambda64(alpha: mx.array, sigma: mx.array) -> mx.array: + # log(alpha/sigma) needs float64 for numerical stability; Metal GPU doesn't support float64. 
+ with mx.stream(mx.cpu): + result = mx.log(alpha.astype(mx.float64)) - mx.log(sigma.astype(mx.float64)) + return result.astype(mx.float32) + + +class FlowUniPCMultistepScheduler: + def __init__( + self, + num_train_timesteps: int = 1000, + solver_order: int = 2, + prediction_type: str = "flow_prediction", + shift: Optional[float] = 1.0, + predict_x0: bool = True, + solver_type: str = "bh2", + lower_order_final: bool = True, + disable_corrector: Optional[List[int]] = None, + final_sigmas_type: str = "zero", + ): + if solver_type not in ["bh1", "bh2"]: + if solver_type in ["midpoint", "heun", "logrho"]: + solver_type = "bh2" + else: + raise NotImplementedError(f"{solver_type} not implemented") + + self.num_train_timesteps = num_train_timesteps + self.solver_order = solver_order + self.prediction_type = prediction_type + self.shift = shift + self.predict_x0 = predict_x0 + self.solver_type = solver_type + self.lower_order_final = lower_order_final + self.disable_corrector = ( + disable_corrector if disable_corrector is not None else [] + ) + self.final_sigmas_type = final_sigmas_type + + sigmas = ( + 1.0 - mx.linspace(1, 1 / num_train_timesteps, num_train_timesteps)[::-1] + ) + sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) + sigmas = sigmas.astype(mx.float32) + + self.sigma_min = float(sigmas[-1].item()) + self.sigma_max = float(sigmas[0].item()) + + self.sigmas = None + self.timesteps = None + self.num_inference_steps = None + self.model_outputs = [None] * solver_order + self.timestep_list = [None] * solver_order + self.lower_order_nums = 0 + self.last_sample = None + self._step_index = None + + @property + def step_index(self): + return self._step_index + + def set_timesteps(self, num_inference_steps, shift=None): + sigmas = mx.linspace(self.sigma_max, self.sigma_min, num_inference_steps + 1)[ + :-1 + ] + if shift is None: + shift = self.shift + sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) + + sigma_last = 0 + timesteps = sigmas * 
self.num_train_timesteps + sigmas = mx.concatenate([sigmas, mx.array([sigma_last])]).astype(mx.float32) + + self.sigmas = sigmas + self.timesteps = timesteps.astype(mx.int32) + self.num_inference_steps = len(timesteps) + self.model_outputs = [None] * self.solver_order + self.lower_order_nums = 0 + self.last_sample = None + self._step_index = None + + def _sigma_to_alpha_sigma_t(self, sigma): + return 1 - sigma, sigma + + def convert_model_output(self, model_output, sample): + sigma_t = self.sigmas[self.step_index] + if self.predict_x0: + return sample - sigma_t * model_output + else: + return sample - (1 - sigma_t) * model_output + + def multistep_uni_p_bh_update(self, model_output, sample, order): + """Predictor step of the UniPC multistep solver. + + Key variables: + rks: Ratios of lambda differences between past and current steps + D1s: First-order finite differences of model outputs + R, b: Linear system for polynomial coefficient computation + h_phi_k: Exponential integrator phi functions + B_h: Scale factor -- expm1(h) for bh2 solver type + rhos_p: Polynomial coefficients from solving R*rhos = b + """ + model_output_list = self.model_outputs + m0 = model_output_list[-1] + x = sample + + sigma_t, sigma_s0 = ( + self.sigmas[self.step_index + 1], + self.sigmas[self.step_index], + ) + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + + lambda_t = _lambda64(alpha_t, sigma_t) + lambda_s0 = _lambda64(alpha_s0, sigma_s0) + h = lambda_t - lambda_s0 + + rks = [] + D1s = [] + for i in range(1, order): + si = self.step_index - i + mi = model_output_list[-(i + 1)] + alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si]) + lambda_si = _lambda64(alpha_si, sigma_si) + rk = (lambda_si - lambda_s0) / h + rks.append(rk) + D1s.append((mi - m0) / rk) + + rks.append(mx.array(1.0, dtype=mx.float32)) + rks = mx.stack(rks) + + R = [] + b = [] + hh = -h if self.predict_x0 else h + h_phi_1 = mx.expm1(hh) 
+ h_phi_k = h_phi_1 / hh - 1 + factorial_i = 1 + + if self.solver_type == "bh1": + B_h = hh + else: + B_h = mx.expm1(hh) + + for i in range(1, order + 1): + R.append(mx.power(rks, i - 1)) + b.append(h_phi_k * factorial_i / B_h) + factorial_i *= i + 1 + h_phi_k = h_phi_k / hh - 1 / factorial_i + + R = mx.stack(R) + b = mx.stack(b) + + if len(D1s) > 0: + D1s = mx.stack(D1s, axis=1) + if order == 2: + rhos_p = mx.array([0.5], dtype=x.dtype) + else: + # Run on CPU for numerical stability (float64 not supported on Metal GPU), + # matching the reference implementation. + with mx.stream(mx.cpu): + rhos_p = mx.linalg.solve(R[:-1, :-1], b[:-1]).astype(x.dtype) + else: + D1s = None + + if self.predict_x0: + x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 + if D1s is not None: + pred_res = mx.sum( + rhos_p.reshape(-1, *([1] * (D1s.ndim - 1))) * D1s, axis=1 + ) + else: + pred_res = 0 + x_t = x_t_ - alpha_t * B_h * pred_res + else: + x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 + if D1s is not None: + pred_res = mx.sum( + rhos_p.reshape(-1, *([1] * (D1s.ndim - 1))) * D1s, axis=1 + ) + else: + pred_res = 0 + x_t = x_t_ - sigma_t * B_h * pred_res + + return x_t.astype(x.dtype) + + def multistep_uni_c_bh_update( + self, this_model_output, last_sample, this_sample, order + ): + model_output_list = self.model_outputs + m0 = model_output_list[-1] + x = last_sample + x_t = this_sample + model_t = this_model_output + + sigma_t, sigma_s0 = ( + self.sigmas[self.step_index], + self.sigmas[self.step_index - 1], + ) + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + + lambda_t = _lambda64(alpha_t, sigma_t) + lambda_s0 = _lambda64(alpha_s0, sigma_s0) + h = lambda_t - lambda_s0 + + rks = [] + D1s = [] + for i in range(1, order): + si = self.step_index - (i + 1) + mi = model_output_list[-(i + 1)] + alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si]) + lambda_si = _lambda64(alpha_si, 
sigma_si) + rk = (lambda_si - lambda_s0) / h + rks.append(rk) + D1s.append((mi - m0) / rk) + + rks.append(mx.array(1.0, dtype=mx.float32)) + rks = mx.stack(rks) + + R = [] + b = [] + hh = -h if self.predict_x0 else h + h_phi_1 = mx.expm1(hh) + h_phi_k = h_phi_1 / hh - 1 + factorial_i = 1 + + if self.solver_type == "bh1": + B_h = hh + else: + B_h = mx.expm1(hh) + + for i in range(1, order + 1): + R.append(mx.power(rks, i - 1)) + b.append(h_phi_k * factorial_i / B_h) + factorial_i *= i + 1 + h_phi_k = h_phi_k / hh - 1 / factorial_i + + R = mx.stack(R) + b = mx.stack(b) + + if len(D1s) > 0: + D1s = mx.stack(D1s, axis=1) + else: + D1s = None + + if order == 1: + rhos_c = mx.array([0.5], dtype=x.dtype) + else: + # Run on CPU for numerical stability (float64 not supported on Metal GPU), + # matching the reference implementation. + with mx.stream(mx.cpu): + rhos_c = mx.linalg.solve(R, b).astype(x.dtype) + + if self.predict_x0: + x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 + if D1s is not None: + corr_res = mx.sum( + rhos_c[:-1].reshape(-1, *([1] * (D1s.ndim - 1))) * D1s, axis=1 + ) + else: + corr_res = 0 + D1_t = model_t - m0 + x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t) + else: + x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 + if D1s is not None: + corr_res = mx.sum( + rhos_c[:-1].reshape(-1, *([1] * (D1s.ndim - 1))) * D1s, axis=1 + ) + else: + corr_res = 0 + D1_t = model_t - m0 + x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t) + + return x_t.astype(x.dtype) + + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + if isinstance(timestep, mx.array): + timestep_val = timestep + else: + timestep_val = mx.array(timestep, dtype=schedule_timesteps.dtype) + diff = mx.abs(schedule_timesteps - timestep_val) + first_idx = mx.argmin(diff) + num_matches = int((diff == 0).sum().item()) + if num_matches > 1: + return int(first_idx.item()) + 1 + return 
int(first_idx.item()) + + def _init_step_index(self, timestep): + self._step_index = self.index_for_timestep(timestep) + + def step(self, model_output, timestep, sample): + if self.num_inference_steps is None: + raise ValueError("Call set_timesteps before step()") + if self.step_index is None: + self._init_step_index(timestep) + + use_corrector = ( + self.step_index > 0 + and self.step_index - 1 not in self.disable_corrector + and self.last_sample is not None + ) + + model_output_convert = self.convert_model_output(model_output, sample=sample) + if use_corrector: + sample = self.multistep_uni_c_bh_update( + this_model_output=model_output_convert, + last_sample=self.last_sample, + this_sample=sample, + order=self.this_order, + ) + + for i in range(self.solver_order - 1): + self.model_outputs[i] = self.model_outputs[i + 1] + self.timestep_list[i] = self.timestep_list[i + 1] + + self.model_outputs[-1] = model_output_convert + self.timestep_list[-1] = timestep + + if self.lower_order_final: + this_order = min(self.solver_order, len(self.timesteps) - self.step_index) + else: + this_order = self.solver_order + + self.this_order = min(this_order, self.lower_order_nums + 1) + assert self.this_order > 0 + + self.last_sample = sample + prev_sample = self.multistep_uni_p_bh_update( + model_output=model_output, + sample=sample, + order=self.this_order, + ) + + if self.lower_order_nums < self.solver_order: + self.lower_order_nums += 1 + + self._step_index += 1 + return prev_sample + + +@partial(mx.compile, shapeless=True) +def _euler_step(model_output, sample, sigma, sigma_next): + return sample + model_output * (sigma_next - sigma) + + +class FlowEulerDiscreteScheduler: + """Simple Euler flow-matching scheduler for step-distilled models. + + Unlike UniPC, this uses a single-step Euler update matching how + step-distilled models were trained. 
Timestep selection uses indexed + positions from the full 1000-step schedule (via denoising_step_list) + rather than linear interpolation. + """ + + def __init__(self, num_train_timesteps=1000): + self.num_train_timesteps = num_train_timesteps + self.timesteps = None + self.sigmas = None + self.num_inference_steps = 0 + + def set_timesteps(self, denoising_step_list, shift=5.0): + """Build schedule by indexing into the full shifted schedule. + + Args: + denoising_step_list: e.g. [1000, 750, 500, 250]. Each value V + maps to index (num_train_timesteps - V) in the shifted schedule. + shift: Noise schedule shift factor (default 5.0 for distilled). + """ + sigmas = mx.linspace(1.0, 0.0, self.num_train_timesteps + 1)[:-1] + sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) + timesteps = sigmas * self.num_train_timesteps + + indices = mx.array([self.num_train_timesteps - x for x in denoising_step_list]) + self.sigmas = sigmas[indices].astype(mx.float32) + self.timesteps = timesteps[indices].astype(mx.float32) + self.num_inference_steps = len(denoising_step_list) + + def step(self, model_output, timestep, sample): + """Euler flow-matching update: x_new = x + f * (sigma_next - sigma).""" + t_val = timestep.item() if isinstance(timestep, mx.array) else timestep + step_index = int( + mx.argmin(mx.abs(self.timesteps.astype(mx.float32) - t_val)).item() + ) + + sigma = self.sigmas[step_index] + sigma_next = ( + self.sigmas[step_index + 1] + if step_index < self.num_inference_steps - 1 + else mx.array(0.0) + ) + return _euler_step( + model_output.astype(mx.float32), + sample.astype(mx.float32), + sigma, + sigma_next, + ).astype(sample.dtype) diff --git a/video/wan2.1/wan/t5.py b/video/wan2.1/wan/t5.py new file mode 100644 index 000000000..8543f6b0b --- /dev/null +++ b/video/wan2.1/wan/t5.py @@ -0,0 +1,190 @@ +# Copyright © 2026 Apple Inc. + +""" +T5 text encoder for Wan2.1. 
+ +UMT5-XXL encoder (4096 dim, 24 layers, 64 heads) with gated GELU FFN +and per-layer relative position embeddings. +""" + +import math +import re +from typing import Dict, Optional + +import mlx.core as mx +import mlx.nn as nn +from einops import rearrange + + +class T5RelativeEmbedding(nn.Module): + def __init__(self, num_buckets, num_heads, bidirectional=True, max_dist=128): + super().__init__() + self.num_buckets = num_buckets + self.num_heads = num_heads + self.bidirectional = bidirectional + self.max_dist = max_dist + self.embedding = nn.Embedding(num_buckets, num_heads) + + def _relative_position_bucket(self, rel_pos): + if self.bidirectional: + num_buckets = self.num_buckets // 2 + rel_buckets = (rel_pos > 0).astype(mx.int32) * num_buckets + rel_pos = mx.abs(rel_pos) + else: + num_buckets = self.num_buckets + rel_buckets = mx.zeros_like(rel_pos, dtype=mx.int32) + rel_pos = -mx.minimum(rel_pos, mx.zeros_like(rel_pos)) + + max_exact = num_buckets // 2 + is_small = rel_pos < max_exact + scale = (num_buckets - max_exact) / math.log(self.max_dist / max_exact) + rel_pos_large = max_exact + ( + mx.log(rel_pos.astype(mx.float32) / max_exact) * scale + ).astype(mx.int32) + rel_pos_large = mx.minimum(rel_pos_large, num_buckets - 1) + rel_buckets = rel_buckets + mx.where(is_small, rel_pos, rel_pos_large) + return rel_buckets + + def __call__(self, lq, lk): + query_pos = mx.arange(lq)[:, None] + key_pos = mx.arange(lk)[None, :] + rel_pos = key_pos - query_pos + rel_buckets = self._relative_position_bucket(rel_pos) + rel_pos_embeds = self.embedding(rel_buckets) + return rel_pos_embeds.transpose(2, 0, 1)[None, :, :, :] + + +class T5Attention(nn.Module): + def __init__(self, dim, dim_attn, num_heads): + super().__init__() + assert dim_attn % num_heads == 0 + self.num_heads = num_heads + self.head_dim = dim_attn // num_heads + + self.q = nn.Linear(dim, dim_attn, bias=False) + self.k = nn.Linear(dim, dim_attn, bias=False) + self.v = nn.Linear(dim, dim_attn, bias=False) + 
self.o = nn.Linear(dim_attn, dim, bias=False) + + def __call__(self, x, context=None, mask=None, pos_bias=None): + context = x if context is None else context + b = x.shape[0] + n, c = self.num_heads, self.head_dim + + q = rearrange(self.q(x), "b s (n c) -> b n s c", n=n) + k = rearrange(self.k(context), "b s (n c) -> b n s c", n=n) + v = rearrange(self.v(context), "b s (n c) -> b n s c", n=n) + + attn_bias = mx.zeros((b, n, q.shape[2], k.shape[2]), dtype=x.dtype) + if pos_bias is not None: + attn_bias = attn_bias + pos_bias + if mask is not None: + if mask.ndim == 2: + mask = mask[:, None, None, :] + else: + mask = mask[:, None, :, :] + attn_bias = mx.where(mask == 0, -1e9, attn_bias) + + # T5 does NOT use sqrt(d) scaling + out = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=attn_bias) + out = rearrange(out, "b n s c -> b s (n c)") + return self.o(out) + + +class T5FeedForward(nn.Module): + def __init__(self, dim, dim_ffn): + super().__init__() + self.gate = nn.Linear(dim, dim_ffn, bias=False) + self.fc1 = nn.Linear(dim, dim_ffn, bias=False) + self.fc2 = nn.Linear(dim_ffn, dim, bias=False) + + def __call__(self, x): + return self.fc2(self.fc1(x) * nn.gelu_approx(self.gate(x))) + + +class T5SelfAttention(nn.Module): + def __init__(self, dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos=True): + super().__init__() + self.shared_pos = shared_pos + self.norm1 = nn.RMSNorm(dim, eps=1e-6) + self.attn = T5Attention(dim, dim_attn, num_heads) + self.norm2 = nn.RMSNorm(dim, eps=1e-6) + self.ffn = T5FeedForward(dim, dim_ffn) + self.pos_embedding = ( + None + if shared_pos + else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True) + ) + + def __call__(self, x, mask=None, pos_bias=None): + e = pos_bias if self.shared_pos else self.pos_embedding(x.shape[1], x.shape[1]) + x = x + self.attn(self.norm1(x), mask=mask, pos_bias=e) + x = x + self.ffn(self.norm2(x)) + return x + + +class T5Encoder(nn.Module): + def __init__( + self, + vocab_size, + 
dim, + dim_attn, + dim_ffn, + num_heads, + num_layers, + num_buckets, + shared_pos=True, + ): + super().__init__() + self.dim = dim + self.num_layers = num_layers + self.shared_pos = shared_pos + self.token_embedding = nn.Embedding(vocab_size, dim) + self.pos_embedding = ( + T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True) + if shared_pos + else None + ) + self.blocks = [ + T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos) + for _ in range(num_layers) + ] + self.norm = nn.RMSNorm(dim, eps=1e-6) + + def __call__(self, ids, mask=None): + x = self.token_embedding(ids) + seq_len = x.shape[1] + e = self.pos_embedding(seq_len, seq_len) if self.shared_pos else None + for block in self.blocks: + x = block(x, mask=mask, pos_bias=e) + x = self.norm(x) + return x + + @staticmethod + def sanitize(weights: Dict[str, mx.array]) -> Dict[str, mx.array]: + """Remap PyTorch T5 keys to MLX format.""" + remapped = {} + for key, value in weights.items(): + new_key = key + if new_key.startswith("model."): + new_key = new_key[6:] + if "ffn.gate.1" in new_key: + continue + if "dropout" in new_key: + continue + new_key = re.sub(r"ffn\.gate\.0\.", "ffn.gate.", new_key) + remapped[new_key] = value + return remapped + + +def create_umt5_xxl_encoder() -> T5Encoder: + return T5Encoder( + vocab_size=256384, + dim=4096, + dim_attn=4096, + dim_ffn=10240, + num_heads=64, + num_layers=24, + num_buckets=32, + shared_pos=False, + ) diff --git a/video/wan2.1/wan/tokenizers.py b/video/wan2.1/wan/tokenizers.py new file mode 100644 index 000000000..88a9f0050 --- /dev/null +++ b/video/wan2.1/wan/tokenizers.py @@ -0,0 +1,47 @@ +# Copyright © 2026 Apple Inc. + +""" +T5 tokenizer for Wan2.1. + +Uses the `tokenizers` library (HuggingFace tokenizers) to load tokenizer.json, +avoiding PyTorch and sentencepiece dependencies. 
+""" + +from typing import Any, Dict + +import mlx.core as mx + + +class T5Tokenizer: + """Pure tokenizer wrapper using HuggingFace tokenizers library.""" + + def __init__(self, tokenizer_path: str): + from tokenizers import Tokenizer + + self._tokenizer = Tokenizer.from_file(tokenizer_path) + self.pad_token_id = self._tokenizer.token_to_id("") or 0 + + def __call__( + self, + text: str, + max_length: int = 512, + padding: str = "max_length", + truncation: bool = True, + ) -> Dict[str, Any]: + encoded = self._tokenizer.encode(text) + input_ids = encoded.ids + + if truncation and len(input_ids) > max_length: + input_ids = input_ids[:max_length] + + attention_mask = [1] * len(input_ids) + + if padding == "max_length" and len(input_ids) < max_length: + pad_length = max_length - len(input_ids) + input_ids = input_ids + [self.pad_token_id] * pad_length + attention_mask = attention_mask + [0] * pad_length + + return { + "input_ids": mx.array([input_ids], dtype=mx.int32), + "attention_mask": mx.array([attention_mask], dtype=mx.int32), + } diff --git a/video/wan2.1/wan/utils.py b/video/wan2.1/wan/utils.py new file mode 100644 index 000000000..6e650d633 --- /dev/null +++ b/video/wan2.1/wan/utils.py @@ -0,0 +1,240 @@ +# Copyright © 2026 Apple Inc. + +""" +Utility functions for Wan2.1 pipeline. + +Weight loading, HF Hub downloading, video saving. 
+""" + +import json +import os +import subprocess +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, Optional + +import mlx.core as mx +import numpy as np + +from .model import WanModel +from .t5 import T5Encoder, create_umt5_xxl_encoder +from .tokenizers import T5Tokenizer +from .vae import WanVAE + + +@dataclass +class ModelSpec: + repo_id: str + repo_dit: str + repo_vae: str + repo_t5: str + dit_params: Dict + repo_tokenizer: str = "google/umt5-xxl/tokenizer.json" + repo_clip: Optional[str] = None + ckpt_path: Optional[str] = None + + +configs = { + "t2v-1.3B": ModelSpec( + repo_id="Wan-AI/Wan2.1-T2V-1.3B", + repo_dit="diffusion_pytorch_model.safetensors", + repo_vae="Wan2.1_VAE.pth", + repo_t5="models_t5_umt5-xxl-enc-bf16.pth", + dit_params={"dim": 1536, "ffn_dim": 8960, "num_heads": 12, "num_layers": 30}, + ckpt_path=os.getenv("WAN_T2V_1_3B"), + ), + "t2v-14B": ModelSpec( + repo_id="Wan-AI/Wan2.1-T2V-14B", + repo_dit="diffusion_pytorch_model.safetensors.index.json", + repo_vae="Wan2.1_VAE.pth", + repo_t5="models_t5_umt5-xxl-enc-bf16.pth", + dit_params={"dim": 5120, "ffn_dim": 13824, "num_heads": 40, "num_layers": 40}, + ckpt_path=os.getenv("WAN_T2V_14B"), + ), + "i2v-14B": ModelSpec( + repo_id="Wan-AI/Wan2.1-I2V-14B-480P", + repo_dit="diffusion_pytorch_model.safetensors.index.json", + repo_vae="Wan2.1_VAE.pth", + repo_t5="models_t5_umt5-xxl-enc-bf16.pth", + repo_clip="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", + dit_params={ + "dim": 5120, + "ffn_dim": 13824, + "num_heads": 40, + "num_layers": 40, + "model_type": "i2v", + "in_dim": 36, + }, + ckpt_path=os.getenv("WAN_I2V_14B"), + ), +} + + +def _hf_download(repo_id: str, filename: str) -> str: + from huggingface_hub import hf_hub_download + + return hf_hub_download(repo_id=repo_id, filename=filename) + + +def _load_weights(path: str) -> dict: + """Unified loader for safetensors, sharded index, and .pth files.""" + assert os.path.isfile(path), f"Weights file at 
{path} does not exist" + if path.endswith(".index.json"): + weight_dir = os.path.dirname(path) + with open(path) as f: + index = json.load(f) + weight_files = set(index["weight_map"].values()) + # Ensure all shards are downloaded (for HF Hub paths) + for wf in weight_files: + shard_path = os.path.join(weight_dir, wf) + if not os.path.exists(shard_path): + # Infer repo_id from HF cache path structure + # .../models--Org--Repo/snapshots/hash/file + parts = Path(path).parts + for i, p in enumerate(parts): + if p.startswith("models--"): + repo_id = p.replace("models--", "").replace("--", "/") + _hf_download(repo_id, wf) + break + weights = {} + for wf in weight_files: + weights.update(mx.load(os.path.join(weight_dir, wf))) + return weights + elif path.endswith(".pth"): + import torch + + sd = torch.load(path, map_location="cpu", weights_only=True) + weights = {k: mx.array(v.float().numpy()) for k, v in sd.items()} + + del torch + + return weights + else: + return mx.load(path) + + +def load_dit(name: str, checkpoint: Optional[str] = None) -> WanModel: + """Load DiT model with weights from HF Hub.""" + spec = configs[name] + model = WanModel(**spec.dit_params) + ckpt_path = checkpoint or spec.ckpt_path + if ckpt_path is None: + ckpt_path = _hf_download(spec.repo_id, spec.repo_dit) + weights = _load_weights(ckpt_path) + weights = WanModel.sanitize(weights) + model.load_weights(list(weights.items()), strict=True) + return model + + +def load_vae(name: str) -> WanVAE: + """Load VAE decoder with weights from HF Hub.""" + spec = configs[name] + vae = WanVAE() + ckpt_path = _hf_download(spec.repo_id, spec.repo_vae) + weights = _load_weights(ckpt_path) + weights = WanVAE.sanitize(weights) + vae.load_weights(list(weights.items()), strict=False) + return vae + + +def load_t5(name: str) -> T5Encoder: + """Load T5 encoder with weights from HF Hub.""" + spec = configs[name] + t5 = create_umt5_xxl_encoder() + weight_path = _hf_download(spec.repo_id, spec.repo_t5) + weights = 
_load_weights(weight_path) + weights = T5Encoder.sanitize(weights) + t5.load_weights(list(weights.items()), strict=True) + return t5 + + +def load_clip(name: str): + """Load CLIP vision encoder with weights from HF Hub.""" + from .clip import CLIPVisionEncoder + + spec = configs[name] + if spec.repo_clip is None: + raise ValueError(f"Model {name} does not have a CLIP config") + clip = CLIPVisionEncoder() + weight_path = _hf_download(spec.repo_id, spec.repo_clip) + weights = _load_weights(weight_path) + weights = CLIPVisionEncoder.sanitize(weights) + clip.load_weights(list(weights.items()), strict=True) + return clip + + +def load_t5_tokenizer(name: str) -> T5Tokenizer: + """Load T5 tokenizer from HF Hub.""" + spec = configs[name] + tok_path = _hf_download(spec.repo_id, spec.repo_tokenizer) + return T5Tokenizer(tok_path) + + +def save_video( + frames: mx.array, + output_path: str, + fps: int = 16, +) -> bool: + """ + Save video frames to file using ffmpeg. + + Args: + frames: Video tensor [T, H, W, C] (channels-last) in [-1, 1] or [0, 1] + output_path: Output file path + fps: Frames per second + + Returns: + True if successful + """ + Path(output_path).parent.mkdir(parents=True, exist_ok=True) + + if frames.ndim == 5: + frames = frames[0] + + # Convert from [-1, 1] to [0, 1] + mx.eval(frames) + if frames.min().item() < 0: + frames = (frames + 1.0) / 2.0 + + frames_np = np.array(mx.clip(frames * 255, 0, 255).astype(mx.uint8)) + T, H, W, C = frames_np.shape + + print(f"Saving {T} frames ({W}x{H}) to {output_path}") + + cmd = [ + "ffmpeg", + "-y", + "-f", + "rawvideo", + "-vcodec", + "rawvideo", + "-s", + f"{W}x{H}", + "-pix_fmt", + "rgb24", + "-r", + str(fps), + "-i", + "-", + "-c:v", + "libx264", + "-pix_fmt", + "yuv420p", + "-crf", + "18", + "-preset", + "fast", + output_path, + ] + + try: + process = subprocess.Popen(cmd, stdin=subprocess.PIPE, stderr=subprocess.PIPE) + _, stderr = process.communicate(input=frames_np.tobytes()) + if process.returncode != 0: + 
print(f"FFmpeg error: {stderr.decode()}") + return False + print(f"Video saved to {output_path}") + return True + except Exception as e: + print(f"Error saving video: {e}") + return False diff --git a/video/wan2.1/wan/vae.py b/video/wan2.1/wan/vae.py new file mode 100644 index 000000000..500631526 --- /dev/null +++ b/video/wan2.1/wan/vae.py @@ -0,0 +1,518 @@ +# Copyright © 2026 Apple Inc. + +""" +Wan2.1 VAE encoder and decoder. + +Encodes video frames to latents and decodes latents to video frames +using chunked processing with causal temporal caching. +""" + +import re +from typing import Dict, List, Optional + +import mlx.core as mx +import mlx.nn as nn + +from .vae_layers import ( + AttentionBlock, + CausalConv3d, + Resample, + ResidualBlock, + create_cache_entry, +) + + +class Decoder3d(nn.Module): + """ + VAE Decoder for video generation. + + Input: [B, T, H/8, W/8, z_dim] (channels-last) + Output: [B, T*4, H, W, 3] + """ + + def __init__( + self, + dim: int = 96, + z_dim: int = 16, + dim_mult: Optional[List[int]] = None, + num_res_blocks: int = 2, + attn_scales: Optional[List[float]] = None, + temporal_upsample: Optional[List[bool]] = None, + ): + super().__init__() + if dim_mult is None: + dim_mult = [1, 2, 4, 4] + if attn_scales is None: + attn_scales = [] + if temporal_upsample is None: + temporal_upsample = [True, True, False] + + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.temporal_upsample = temporal_upsample + + dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] + scale = 1.0 / 2 ** (len(dim_mult) - 2) + + self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1) + + self.middle_res1 = ResidualBlock(dims[0], dims[0]) + self.middle_attn = AttentionBlock(dims[0]) + self.middle_res2 = ResidualBlock(dims[0], dims[0]) + + # Build upsample stages as nested lists + self.upsamples = [] + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + stage = [] + if i == 1 or i == 2 or i 
== 3: + in_dim = in_dim // 2 + for j in range(num_res_blocks + 1): + stage.append(ResidualBlock(in_dim, out_dim)) + if scale in attn_scales: + stage.append(AttentionBlock(out_dim)) + in_dim = out_dim + if i != len(dim_mult) - 1: + mode = "upsample3d" if temporal_upsample[i] else "upsample2d" + stage.append(Resample(out_dim, mode=mode)) + scale *= 2.0 + self.upsamples.append(stage) + + self.head_norm = nn.RMSNorm(dims[-1], eps=1e-12) + self.head_conv = CausalConv3d(dims[-1], 3, 3, padding=1) + + # Count temporal cache slots from architecture + n = 1 + 2 + 2 # conv1, middle_res1, middle_res2 + for stage in self.upsamples: + for layer in stage: + if isinstance(layer, ResidualBlock): + n += 2 + elif isinstance(layer, Resample) and hasattr(layer, "time_conv"): + n += 1 + n += 1 # head_conv + self.num_cache_slots = n + + def __call__(self, x, feat_cache): + cache_idx = 0 + new_cache = [] + + cache_input = x + x = self.conv1(x, feat_cache[cache_idx]) + new_cache.append(create_cache_entry(cache_input, feat_cache[cache_idx])) + cache_idx += 1 + + x, c1, c2 = self.middle_res1( + x, feat_cache[cache_idx], feat_cache[cache_idx + 1] + ) + new_cache.append(c1) + new_cache.append(c2) + cache_idx += 2 + + x = self.middle_attn(x) + + x, c1, c2 = self.middle_res2( + x, feat_cache[cache_idx], feat_cache[cache_idx + 1] + ) + new_cache.append(c1) + new_cache.append(c2) + cache_idx += 2 + + for stage in self.upsamples: + for layer in stage: + if isinstance(layer, ResidualBlock): + x, c1, c2 = layer( + x, feat_cache[cache_idx], feat_cache[cache_idx + 1] + ) + new_cache.append(c1) + new_cache.append(c2) + cache_idx += 2 + elif isinstance(layer, AttentionBlock): + x = layer(x) + elif isinstance(layer, Resample): + x, c = layer(x, feat_cache[cache_idx]) + if c is not None: + new_cache.append(c) + cache_idx += 1 + + x = self.head_norm(x) + x = nn.silu(x) + cache_input = x + x = self.head_conv(x, feat_cache[cache_idx]) + new_cache.append(create_cache_entry(cache_input, feat_cache[cache_idx])) 
+ cache_idx += 1 + + return x, new_cache + + +class Encoder3d(nn.Module): + """ + VAE Encoder for video generation. + + Input: [B, T, H, W, 3] (channels-last) + Output: [B, T', H/8, W/8, z_dim*2] + """ + + def __init__( + self, + dim: int = 96, + z_dim: int = 16, + dim_mult: Optional[List[int]] = None, + num_res_blocks: int = 2, + attn_scales: Optional[List[float]] = None, + temporal_downsample: Optional[List[bool]] = None, + ): + super().__init__() + if dim_mult is None: + dim_mult = [1, 2, 4, 4] + if attn_scales is None: + attn_scales = [] + if temporal_downsample is None: + temporal_downsample = [False, True, True] + + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.temporal_downsample = temporal_downsample + + dims = [dim * u for u in [1] + dim_mult] + scale = 1.0 + + self.conv1 = CausalConv3d(3, dims[0], 3, padding=1) + + # Build downsample stages as nested lists + self.downsamples = [] + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + stage = [] + for j in range(num_res_blocks): + stage.append(ResidualBlock(in_dim, out_dim)) + if scale in attn_scales: + stage.append(AttentionBlock(out_dim)) + in_dim = out_dim + if i != len(dim_mult) - 1: + mode = "downsample3d" if temporal_downsample[i] else "downsample2d" + stage.append(Resample(out_dim, mode=mode)) + scale /= 2.0 + self.downsamples.append(stage) + + self.middle_res1 = ResidualBlock(dims[-1], dims[-1]) + self.middle_attn = AttentionBlock(dims[-1]) + self.middle_res2 = ResidualBlock(dims[-1], dims[-1]) + + self.head_norm = nn.RMSNorm(dims[-1], eps=1e-12) + self.head_conv = CausalConv3d(dims[-1], z_dim * 2, 3, padding=1) + + # Count temporal cache slots from architecture + n = 1 # conv1 + for stage in self.downsamples: + for layer in stage: + if isinstance(layer, ResidualBlock): + n += 2 + elif isinstance(layer, Resample) and hasattr(layer, "time_conv"): + n += 1 + n += 2 + 2 + 1 # middle_res1, middle_res2, head_conv + 
self.num_cache_slots = n + + def __call__(self, x, feat_cache): + cache_idx = 0 + new_cache = [] + + cache_input = x + x = self.conv1(x, feat_cache[cache_idx]) + new_cache.append(create_cache_entry(cache_input, feat_cache[cache_idx])) + cache_idx += 1 + + for stage in self.downsamples: + for layer in stage: + if isinstance(layer, ResidualBlock): + x, c1, c2 = layer( + x, feat_cache[cache_idx], feat_cache[cache_idx + 1] + ) + new_cache.append(c1) + new_cache.append(c2) + cache_idx += 2 + elif isinstance(layer, AttentionBlock): + x = layer(x) + elif isinstance(layer, Resample): + x, c = layer(x, feat_cache[cache_idx]) + if c is not None: + new_cache.append(c) + cache_idx += 1 + + x, c1, c2 = self.middle_res1( + x, feat_cache[cache_idx], feat_cache[cache_idx + 1] + ) + new_cache.append(c1) + new_cache.append(c2) + cache_idx += 2 + + x = self.middle_attn(x) + + x, c1, c2 = self.middle_res2( + x, feat_cache[cache_idx], feat_cache[cache_idx + 1] + ) + new_cache.append(c1) + new_cache.append(c2) + cache_idx += 2 + + x = self.head_norm(x) + x = nn.silu(x) + cache_input = x + x = self.head_conv(x, feat_cache[cache_idx]) + new_cache.append(create_cache_entry(cache_input, feat_cache[cache_idx])) + cache_idx += 1 + + return x, new_cache + + +class WanVAE(nn.Module): + """ + High-level VAE wrapper for Wan2.1. 
+ + Encode: [F, H, W, C] video -> [F', H/8, W/8, z_dim] latent + Decode: [F, H, W, C] latent -> [F*4, H*8, W*8, 3] video clamped to [-1, 1] + """ + + def __init__(self): + super().__init__() + self.encoder = Encoder3d() + self.conv1 = CausalConv3d(32, 32, 1) + self.decoder = Decoder3d() + self.conv2 = CausalConv3d(16, 16, 1) + + self.mean = mx.array( + [ + -0.7571, + -0.7089, + -0.9113, + 0.1075, + -0.1745, + 0.9653, + -0.1517, + 1.5508, + 0.4134, + -0.0715, + 0.5517, + -0.3632, + -0.1922, + -0.9497, + 0.2503, + -0.2921, + ] + ) + self.std = mx.array( + [ + 2.8184, + 1.4541, + 2.3275, + 2.6558, + 1.2196, + 1.7708, + 2.6052, + 2.0743, + 3.2687, + 2.1526, + 2.8652, + 1.5579, + 1.6382, + 1.1253, + 2.8251, + 1.9160, + ] + ) + self.z_dim = 16 + # Pre-compile for the frame-by-frame loop: avoids recompiling each frame. + self._compiled_decode = mx.compile(self.decoder.__call__) + self._compiled_encode = mx.compile(self.encoder.__call__) + + def decode(self, z: mx.array) -> mx.array: + """ + Decode latent to video. + + Args: + z: Latent tensor [F, H, W, C] (channels-last) + + Returns: + Video tensor [F, H, W, C] clamped to [-1, 1] (channels-last) + """ + # Add batch dim: [F, H, W, C] -> [1, F, H, W, C] + z = z[None] + + # Unscale latents + scale = 1.0 / self.std + z = z / scale.reshape(1, 1, 1, 1, self.z_dim) + self.mean.reshape( + 1, 1, 1, 1, self.z_dim + ) + + # Pre-decoder conv + x = self.conv2(z) + + # Decode one frame at a time. mx.eval per frame releases intermediates, keeping memory bounded. 
+ num_frames = x.shape[1] + feat_cache = [None] * self.decoder.num_cache_slots + outputs = [] + + for i in range(num_frames): + frame = x[:, i : i + 1, :, :, :] + out_frame, feat_cache = self._compiled_decode(frame, feat_cache) + mx.eval(out_frame) + outputs.append(out_frame) + + out = mx.concatenate(outputs, axis=1) + out = mx.clip(out, -1.0, 1.0) + + # Remove batch dim: [1, F, H, W, C] -> [F, H, W, C] + return out[0] + + def encode(self, x: mx.array) -> mx.array: + """ + Encode video to latent. + + Args: + x: Video tensor [F, H, W, C] (channels-last) + + Returns: + Latent tensor [F', H/8, W/8, C] (channels-last) + """ + # Add batch dim: [F, H, W, C] -> [1, F, H, W, C] + x = x[None] + + num_frames = x.shape[1] + feat_cache = [None] * self.encoder.num_cache_slots + outputs = [] + + # First chunk is 1 frame (causal init), subsequent chunks are 4 frames (matching VAE temporal stride). + i = 0 + chunk_idx = 0 + while i < num_frames: + if chunk_idx == 0: + chunk = x[:, i : i + 1, :, :, :] + i += 1 + else: + chunk = x[:, i : i + 4, :, :, :] + i += 4 + + out_chunk, feat_cache = self._compiled_encode(chunk, feat_cache) + mx.eval(out_chunk) + outputs.append(out_chunk) + chunk_idx += 1 + + out = mx.concatenate(outputs, axis=1) + + # Post-encoder conv and extract mu + out = self.conv1(out) + mu = out[:, :, :, :, : self.z_dim] + + # Scale: (mu - mean) * (1/std) + scale = 1.0 / self.std + mu = (mu - self.mean.reshape(1, 1, 1, 1, self.z_dim)) * scale.reshape( + 1, 1, 1, 1, self.z_dim + ) + + # Remove batch dim: [1, F', H', W', C] -> [F', H', W', C] + return mu[0] + + @staticmethod + def sanitize(weights: Dict[str, mx.array]) -> Dict[str, mx.array]: + """Remap PyTorch VAE keys to MLX format.""" + remapped = {} + for key, value in weights.items(): + new_key = key + + # Transpose convolution weights + if "weight" in new_key: + if len(value.shape) == 5: + value = mx.transpose(value, (0, 2, 3, 4, 1)) + elif len(value.shape) == 4: + value = mx.transpose(value, (0, 2, 3, 1)) + + 
new_key = new_key.replace(".gamma", ".weight") + new_key = new_key.replace("decoder.middle.0.", "decoder.middle_res1.") + new_key = new_key.replace("decoder.middle.1.", "decoder.middle_attn.") + new_key = new_key.replace("decoder.middle.2.", "decoder.middle_res2.") + new_key = new_key.replace("decoder.head.0.", "decoder.head_norm.") + new_key = new_key.replace("decoder.head.2.", "decoder.head_conv.") + new_key = new_key.replace("encoder.middle.0.", "encoder.middle_res1.") + new_key = new_key.replace("encoder.middle.1.", "encoder.middle_attn.") + new_key = new_key.replace("encoder.middle.2.", "encoder.middle_res2.") + new_key = new_key.replace("encoder.head.0.", "encoder.head_norm.") + new_key = new_key.replace("encoder.head.2.", "encoder.head_conv.") + + if "decoder.upsamples." in new_key: + new_key = _map_vae_upsample_key(new_key) + if "encoder.downsamples." in new_key: + new_key = _map_vae_downsample_key(new_key) + + new_key = re.sub(r"\.residual\.0\.", ".norm1.", new_key) + new_key = re.sub(r"\.residual\.2\.", ".conv1.", new_key) + new_key = re.sub(r"\.residual\.3\.", ".norm2.", new_key) + new_key = re.sub(r"\.residual\.6\.", ".conv2.", new_key) + + # Resample conv: .resample.1. -> .conv. 
+ new_key = re.sub(r"\.resample\.1\.", ".conv.", new_key) + + # Squeeze 1x1 conv weights to 2D for nn.Linear (to_qkv, proj) + if ("to_qkv" in new_key or "proj" in new_key) and "weight" in new_key: + if ( + len(value.shape) == 4 + and value.shape[1] == 1 + and value.shape[2] == 1 + ): + value = value.reshape(value.shape[0], value.shape[3]) + + # Squeeze norm weights to 1D (required — nn.RMSNorm expects 1D) + if "norm" in new_key and "weight" in new_key: + if len(value.shape) > 1: + value = mx.squeeze(value) + + remapped[new_key] = value + return remapped + + +def _map_vae_upsample_key(key: str) -> str: + match = re.match(r"decoder\.upsamples\.(\d+)\.(.*)", key) + if not match: + return key + + layer_idx = int(match.group(1)) + rest = match.group(2) + + # Decoder stages: (num_res_blocks+1) ResBlocks + 1 Resample each, except last (no Resample). + # Assumes attn_scales=[] (Wan2.1 default — no AttentionBlocks in stages). + num_res_blocks, num_stages = 2, 4 + stage_sizes = [num_res_blocks + 2] * (num_stages - 1) + [num_res_blocks + 1] + stage = 0 + local_idx = layer_idx + + for s, size in enumerate(stage_sizes): + if local_idx < size: + stage = s + break + local_idx -= size + + return f"decoder.upsamples.{stage}.{local_idx}.{rest}" + + +def _map_vae_downsample_key(key: str) -> str: + match = re.match(r"encoder\.downsamples\.(\d+)\.(.*)", key) + if not match: + return key + + layer_idx = int(match.group(1)) + rest = match.group(2) + + # Encoder stages: num_res_blocks ResBlocks + 1 Resample each, except last (no Resample). + # Assumes attn_scales=[] (Wan2.1 default — no AttentionBlocks in stages). 
+ num_res_blocks, num_stages = 2, 4 + stage_sizes = [num_res_blocks + 1] * (num_stages - 1) + [num_res_blocks] + stage = 0 + local_idx = layer_idx + + for s, size in enumerate(stage_sizes): + if local_idx < size: + stage = s + break + local_idx -= size + + return f"encoder.downsamples.{stage}.{local_idx}.{rest}" diff --git a/video/wan2.1/wan/vae_layers.py b/video/wan2.1/wan/vae_layers.py new file mode 100644 index 000000000..194315cdd --- /dev/null +++ b/video/wan2.1/wan/vae_layers.py @@ -0,0 +1,229 @@ +# Copyright © 2026 Apple Inc. + +""" +Building blocks for Wan2.1 VAE. + +All layers use channels-last format (NTHWC) as required by MLX. +""" + +import mlx.core as mx +import mlx.nn as nn + +# Temporal cache depth: 2 frames for causal conv with kernel_size=3 along time. +CACHE_T = 2 + + +def _normalize_tuple(value, n): + if isinstance(value, int): + return (value,) * n + return tuple(value) + + +def create_cache_entry(x, existing_cache=None): + """Build temporal cache from the last CACHE_T frames of x, merging with existing cache.""" + t = x.shape[1] + if t >= CACHE_T: + return x[:, -CACHE_T:, :, :, :] + else: + cache_x = x[:, -t:, :, :, :] + if existing_cache is not None: + old_frames = existing_cache[:, -(CACHE_T - t) :, :, :, :] + return mx.concatenate([old_frames, cache_x], axis=1) + else: + pad_t = CACHE_T - t + zeros = mx.zeros((x.shape[0], pad_t, *x.shape[2:]), dtype=x.dtype) + return mx.concatenate([zeros, cache_x], axis=1) + + +class CausalConv3d(nn.Module): + def __init__( + self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = _normalize_tuple(kernel_size, 3) + self.stride = _normalize_tuple(stride, 3) + self.padding = _normalize_tuple(padding, 3) + self._temporal_pad = self.padding[0] * 2 + self._spatial_pad_h = self.padding[1] + self._spatial_pad_w = self.padding[2] + + scale = ( + 1.0 + / ( + in_channels + * 
self.kernel_size[0] + * self.kernel_size[1] + * self.kernel_size[2] + ) + ** 0.5 + ) + self.weight = mx.random.uniform( + low=-scale, + high=scale, + shape=(out_channels, *self.kernel_size, in_channels), + ) + if bias: + self.bias = mx.zeros((out_channels,)) + + def __call__(self, x, cache_x=None): + # Causal temporal padding (left-only), then symmetric spatial padding, then conv with padding=0. + temporal_pad = self._temporal_pad + if cache_x is not None and self._temporal_pad > 0: + x = mx.concatenate([cache_x, x], axis=1) + temporal_pad = max(0, self._temporal_pad - cache_x.shape[1]) + + if temporal_pad > 0: + x = mx.pad(x, [(0, 0), (temporal_pad, 0), (0, 0), (0, 0), (0, 0)]) + + if self._spatial_pad_h > 0 or self._spatial_pad_w > 0: + x = mx.pad( + x, + [ + (0, 0), + (0, 0), + (self._spatial_pad_h, self._spatial_pad_h), + (self._spatial_pad_w, self._spatial_pad_w), + (0, 0), + ], + ) + + y = mx.conv3d(x, self.weight, stride=self.stride, padding=0) + if "bias" in self: + y = y + self.bias + return y + + +class Resample(nn.Module): + def __init__(self, dim, mode): + assert mode in ( + "upsample2d", + "upsample3d", + "downsample2d", + "downsample3d", + ) + super().__init__() + self.dim = dim + self.mode = mode + + if mode == "upsample2d": + self.upsample = nn.Upsample(scale_factor=(2.0, 2.0), mode="nearest") + self.conv = nn.Conv2d( + dim, dim // 2, kernel_size=3, stride=1, padding=0, bias=True + ) + elif mode == "upsample3d": + self.upsample = nn.Upsample(scale_factor=(2.0, 2.0), mode="nearest") + self.conv = nn.Conv2d( + dim, dim // 2, kernel_size=3, stride=1, padding=0, bias=True + ) + self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) + elif mode in ("downsample2d", "downsample3d"): + self.conv = nn.Conv2d( + dim, dim, kernel_size=3, stride=2, padding=0, bias=True + ) + if mode == "downsample3d": + self.time_conv = CausalConv3d( + dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0) + ) + + def __call__(self, x, cache=None): + b, t, 
h, w, c = x.shape + new_cache = None + + if self.mode == "upsample3d": + if cache is None: + new_cache = mx.zeros((b, CACHE_T, h, w, c), dtype=x.dtype) + else: + cache_input = x + x = self.time_conv(x, cache) + new_cache = create_cache_entry(cache_input, cache) + x = x.reshape(b, t, h, w, 2, c) + x = x.transpose(0, 1, 4, 2, 3, 5) + x = x.reshape(b, t * 2, h, w, c) + + t_out = x.shape[1] + c_out = x.shape[4] + x = x.reshape(b * t_out, x.shape[2], x.shape[3], c_out) + + if self.mode in ("upsample2d", "upsample3d"): + x = self.upsample(x) + x = mx.pad(x, [(0, 0), (1, 1), (1, 1), (0, 0)]) + x = self.conv(x) + elif self.mode in ("downsample2d", "downsample3d"): + x = mx.pad(x, [(0, 0), (0, 1), (0, 1), (0, 0)]) + x = self.conv(x) + + x = x.reshape(b, t_out, x.shape[1], x.shape[2], x.shape[3]) + + if self.mode == "downsample3d": + if cache is None: + new_cache = x + else: + x_with_cache = mx.concatenate([cache[:, -1:, :, :, :], x], axis=1) + new_cache = x[:, -1:, :, :, :] + x = self.time_conv(x_with_cache, None) + + return x, new_cache + + +class ResidualBlock(nn.Module): + def __init__(self, in_dim, out_dim): + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + self.norm1 = nn.RMSNorm(in_dim, eps=1e-12) + self.conv1 = CausalConv3d(in_dim, out_dim, 3, padding=1) + self.norm2 = nn.RMSNorm(out_dim, eps=1e-12) + self.conv2 = CausalConv3d(out_dim, out_dim, 3, padding=1) + if in_dim != out_dim: + self.shortcut = CausalConv3d(in_dim, out_dim, 1) + else: + self.shortcut = None + + def __call__(self, x, cache1, cache2): + if self.shortcut is not None: + h = self.shortcut(x) + else: + h = x + + residual = self.norm1(x) + residual = nn.silu(residual) + cache_input = residual + residual = self.conv1(residual, cache1) + new_cache1 = create_cache_entry(cache_input, cache1) + + residual = self.norm2(residual) + residual = nn.silu(residual) + cache_input = residual + residual = self.conv2(residual, cache2) + new_cache2 = create_cache_entry(cache_input, cache2) + + 
class AttentionBlock(nn.Module):
    """Single-head spatial self-attention applied independently per frame.

    Each of the b*t frames is flattened to h*w tokens of width dim;
    attention mixes tokens within a frame only (no temporal mixing).
    Input and output are channels-last (N, T, H, W, C).
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.norm = nn.RMSNorm(dim, eps=1e-12)
        self.to_qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

    def __call__(self, x):
        b, t, h, w, c = x.shape
        frames = b * t
        tokens = h * w

        normed = self.norm(x.reshape(frames, h, w, c))
        # Fused qkv projection; last axis splits into q, k, v chunks of c.
        qkv = self.to_qkv(normed).reshape(frames, tokens, 3, c)
        # Insert a singleton head axis (B, H=1, L, D) for SDPA.
        q = qkv[:, :, 0, :].reshape(frames, 1, tokens, c)
        k = qkv[:, :, 1, :].reshape(frames, 1, tokens, c)
        v = qkv[:, :, 2, :].reshape(frames, 1, tokens, c)

        attn = mx.fast.scaled_dot_product_attention(q, k, v, scale=c**-0.5)
        out = self.proj(attn.squeeze(1).reshape(frames, h, w, c))
        # Residual connection back onto the un-normalized input.
        return out.reshape(b, t, h, w, c) + x