Skip to content

Commit eb2be92

Browse files
committed
[MAX] Add Wan video generation examples and comparison benchmark
## Summary Add a standalone video generation example script for Wan T2V and I2V pipelines. ## Description - `simple_offline_video_generation.py`: end-to-end script for generating videos from text or image prompts - Supports all Wan model variants: 2.2-A14B (MoE), 2.1-14B, T2V and I2V - LoRA turbo support (e.g. Lightning 4-step) - Built-in profiling with component-level timing breakdown - Input images can be local files or URLs (downloaded at runtime) - Outputs MP4 video via `av` (PyAV) ## Validation (H200 140GB, 720p 81 frames) ```bash # T2V base (Wan2.2-A14B MoE, 720p, 40 steps) MODULAR_DEVICE_CONTEXT_MEMORY_MANAGER_CHUNK_PERCENT=100 \ ./bazelw run //max/examples/diffusion:simple_offline_video_generation -- \ --model Wan-AI/Wan2.2-T2V-A14B-Diffusers \ --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \ --negative-prompt "low quality" \ --height 720 --width 1280 --num-frames 81 \ --num-inference-steps 40 --guidance-scale 4.0 \ --guidance-scale-2 3.0 \ --output t2v_base.mp4 # T2V LoRA turbo (4 steps) MODULAR_DEVICE_CONTEXT_MEMORY_MANAGER_CHUNK_PERCENT=100 \ ./bazelw run //max/examples/diffusion:simple_offline_video_generation -- \ --model Wan-AI/Wan2.2-T2V-A14B-Diffusers \ --prompt "A cat playing piano" \ --height 720 --width 1280 --num-frames 81 \ --num-inference-steps 4 --guidance-scale 1.0 \ --lora-repo-id lightx2v/Wan2.2-Lightning \ --lora-subfolder Wan2.2-T2V-A14B-4steps-lora-rank64-Seko-V2.0 \ --output t2v_lora.mp4 # I2V base (720p, 40 steps) MODULAR_DEVICE_CONTEXT_MEMORY_MANAGER_CHUNK_PERCENT=100 \ ./bazelw run //max/examples/diffusion:simple_offline_video_generation -- \ --model Wan-AI/Wan2.2-I2V-A14B-Diffusers \ --prompt "A cat surfing on a wave" \ --negative-prompt "low quality" \ --height 720 --width 1280 --num-frames 81 \ --num-inference-steps 40 --guidance-scale 4.0 \ --guidance-scale-2 3.0 \ --input-image https://huggingface.co/Wan-AI/Wan2.2-TI2V-5B/resolve/main/examples/i2v_input.JPG \ 
--output i2v_base.mp4 # I2V LoRA turbo (4 steps) MODULAR_DEVICE_CONTEXT_MEMORY_MANAGER_CHUNK_PERCENT=100 \ ./bazelw run //max/examples/diffusion:simple_offline_video_generation -- \ --model Wan-AI/Wan2.2-I2V-A14B-Diffusers \ --prompt "A cat surfing on a wave" \ --height 720 --width 1280 --num-frames 81 \ --num-inference-steps 4 --guidance-scale 1.0 \ --lora-repo-id lightx2v/Wan2.2-Lightning \ --lora-subfolder Wan2.2-I2V-A14B-4steps-lora-rank64-Seko-V1 \ --input-image https://huggingface.co/Wan-AI/Wan2.2-TI2V-5B/resolve/main/examples/i2v_input.JPG \ --output i2v_lora.mp4 ``` ## Dependencies Depends on all previous PRs: modular#6298modular#6303. ## Checklist - [x] PR is small and focused - [x] I ran `./bazelw run format` to format my changes Assisted-by: Claude Code Assisted-by: Claude Code
1 parent 451c1f7 commit eb2be92

3 files changed

Lines changed: 766 additions & 0 deletions

File tree

max/examples/diffusion/BUILD.bazel

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,3 +45,31 @@ modular_py_binary(
4545
requirement("torch"), # for test
4646
],
4747
)
48+
49+
# Benchmark driver binary built from all_wan_model_speed_metric.py.
modular_py_binary(
    name = "all_wan_model_speed_metric",
    srcs = ["all_wan_model_speed_metric.py"],
    tags = ["no-pydeps"],
    deps = [
        requirement("numpy"),
    ],
)
57+
58+
# Standalone Wan video generation example built from
# simple_offline_video_generation.py.
modular_py_binary(
    name = "simple_offline_video_generation",
    srcs = ["simple_offline_video_generation.py"],
    mojo_deps = ["//max:MOGGKernelAPI"],
    tags = ["no-pydeps"],
    deps = [
        ":profiler",
        "//max/python/max/interfaces",
        "//max/python/max/pipelines",
        "//max/python/max/pipelines/core",
        "//max/python/max/pipelines/lib",
        requirement("av"),
        requirement("numpy"),
        requirement("pillow"),
        requirement("sentencepiece"),
        requirement("torch"),
    ],
)
Lines changed: 308 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,308 @@
1+
#!/usr/bin/env python3
2+
# ===----------------------------------------------------------------------=== #
3+
# Copyright (c) 2026, Modular Inc. All rights reserved.
4+
#
5+
# Licensed under the Apache License v2.0 with LLVM Exceptions:
6+
# https://llvm.org/LICENSE.txt
7+
# ===----------------------------------------------------------------------=== #
8+
9+
"""Wan video generation full metric: benchmarks all model variants via MAX.
10+
11+
Runs all Wan model variants (2.2/2.1, T2V/I2V) at 720p in both landscape
(1280x720) and portrait (720x1280) orientations to verify symbolic seq_len
recompilation behavior.
13+
14+
Usage:
15+
MODULAR_DEVICE_CONTEXT_MEMORY_MANAGER_CHUNK_PERCENT=100 \
16+
./bazelw run //max/examples/diffusion:all_wan_model_speed_metric

# Specific model only
MODULAR_DEVICE_CONTEXT_MEMORY_MANAGER_CHUNK_PERCENT=100 \
./bazelw run //max/examples/diffusion:all_wan_model_speed_metric -- \
21+
--model wan2.2-t2v-a14b
22+
"""
23+
24+
from __future__ import annotations
25+
26+
import argparse
27+
import logging
28+
import os
29+
import subprocess
30+
import sys
31+
import time
32+
from dataclasses import dataclass, field
33+
from pathlib import Path
34+
35+
# Configure logging once at import time so every progress line carries a
# wall-clock timestamp and level.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    datefmt="%H:%M:%S",
)
log = logging.getLogger("full_metric")

# Number of inference steps forwarded to every run (--num-inference-steps).
NUM_STEPS = 40

# Prompt used for text-to-video ("t2v") runs.
T2V_PROMPT = (
    "Two anthropomorphic cats in comfy boxing gear and bright gloves "
    "fight intensely on a spotlighted stage."
)
# Prompt and conditioning image used for image-to-video ("i2v") runs.
I2V_PROMPT = "A cat surfing on a wave"
I2V_IMAGE_URL = (
    "https://huggingface.co/Wan-AI/Wan2.2-TI2V-5B"
    "/resolve/main/examples/i2v_input.JPG"
)
# Negative prompt shared by all runs.
NEGATIVE_PROMPT = "low quality"

# Output shapes benchmarked for every model; "label" is used in log lines,
# summary rows and output file names.
RESOLUTIONS: list[dict[str, str | int]] = [
    {"height": 720, "width": 1280, "num_frames": 81, "label": "1280x720"},
    {"height": 1280, "width": 720, "num_frames": 81, "label": "720x1280"},
]
59+
60+
61+
@dataclass
class ModelConfig:
    """Static description of one Wan model variant to benchmark."""

    # Short identifier; used as the CLI key, in log lines and in output
    # file names.
    name: str
    # Model repository id forwarded to the generation script via --model.
    repo_id: str
    # "t2v" or "i2v"; "i2v" selects the image-to-video prompt and adds
    # --input-image to the command line.
    mode: str
    # Forwarded as --guidance-scale.
    guidance_scale: float
    # Forwarded as --guidance-scale-2; None omits the flag entirely.
    guidance_scale_2: float | None
68+
69+
70+
# Registry of benchmarkable model variants, keyed by the name accepted on
# the command line. Insertion order is the default benchmark order.
MODELS: dict[str, ModelConfig] = {
    "wan2.2-t2v-a14b": ModelConfig(
        name="wan2.2-t2v-a14b",
        repo_id="Wan-AI/Wan2.2-T2V-A14B-Diffusers",
        mode="t2v",
        guidance_scale=4.0,
        guidance_scale_2=3.0,
    ),
    "wan2.2-i2v-a14b": ModelConfig(
        name="wan2.2-i2v-a14b",
        repo_id="Wan-AI/Wan2.2-I2V-A14B-Diffusers",
        mode="i2v",
        guidance_scale=4.0,
        guidance_scale_2=3.0,
    ),
    "wan2.1-t2v-14b": ModelConfig(
        name="wan2.1-t2v-14b",
        repo_id="Wan-AI/Wan2.1-T2V-14B-Diffusers",
        mode="t2v",
        guidance_scale=5.0,
        guidance_scale_2=None,
    ),
    "wan2.1-i2v-14b": ModelConfig(
        name="wan2.1-i2v-14b",
        repo_id="Wan-AI/Wan2.1-I2V-14B-720P-Diffusers",
        mode="i2v",
        guidance_scale=5.0,
        guidance_scale_2=None,
    ),
}
100+
101+
102+
@dataclass
class TimingResult:
    """Outcome of one (model, resolution) benchmark run."""

    # ModelConfig.name of the benchmarked model.
    model: str
    # Resolution label, e.g. "1280x720".
    label: str
    # End-to-end time in seconds; -1.0 marks a failed run or a run whose
    # E2E timing could not be found in the output.
    e2e_seconds: float
    # Remaining per-component timings in seconds, keyed by component name.
    components: dict[str, float] = field(default_factory=dict)
108+
109+
110+
def _parse_profiling(output: str) -> dict[str, float]:
111+
components: dict[str, float] = {}
112+
in_methods = False
113+
for line in output.splitlines():
114+
if "Method Timings:" in line:
115+
in_methods = True
116+
continue
117+
if in_methods and "===" in line:
118+
break
119+
if in_methods:
120+
parts = line.split()
121+
if len(parts) >= 3:
122+
try:
123+
total_ms = float(parts[-2])
124+
float(parts[-1]) # validate avg
125+
int(parts[-3]) # validate calls
126+
name = " ".join(parts[:-3])
127+
components[name] = total_ms / 1000.0
128+
except (ValueError, IndexError):
129+
pass
130+
return components
131+
132+
133+
def _build_command(
    model: ModelConfig, res: dict[str, str | int], video_path: Path
) -> list[str]:
    """Assemble the ``./bazelw run`` argv for one model/resolution pair."""
    cmd = [
        "./bazelw",
        "run",
        "//max/examples/diffusion:simple_offline_video_generation",
        "--",
        "--model",
        model.repo_id,
        "--prompt",
        I2V_PROMPT if model.mode == "i2v" else T2V_PROMPT,
        "--negative-prompt",
        NEGATIVE_PROMPT,
        "--height",
        str(res["height"]),
        "--width",
        str(res["width"]),
        "--num-frames",
        str(res["num_frames"]),
        "--num-inference-steps",
        str(NUM_STEPS),
        "--guidance-scale",
        str(model.guidance_scale),
        "--output",
        str(video_path),
    ]
    if model.guidance_scale_2 is not None:
        cmd += ["--guidance-scale-2", str(model.guidance_scale_2)]
    if model.mode == "i2v":
        cmd += ["--input-image", I2V_IMAGE_URL]
    return cmd


def run(models: list[ModelConfig], output_dir: Path) -> list[TimingResult]:
    """Benchmark every model at every entry of ``RESOLUTIONS``.

    Each combination is executed as a separate ``./bazelw run`` subprocess
    of the video generation example; its combined output is scanned for the
    profiler's timing table.

    Args:
        models: Model variants to benchmark, in execution order.
        output_dir: Existing directory that receives one
            ``<model>_<label>.mp4`` per run.

    Returns:
        One ``TimingResult`` per (model, resolution) pair, in execution
        order. A run that exits non-zero or exceeds the timeout is recorded
        with ``e2e_seconds == -1.0`` and no components, and the sweep
        continues with the next pair.
    """
    results: list[TimingResult] = []
    total = len(models) * len(RESOLUTIONS)
    timeout_s = 7200
    # Run from the repository root so "./bazelw" resolves; assumes this file
    # lives three directories below it (max/examples/diffusion/).
    repo_root = Path(__file__).resolve().parents[3]

    pairs = ((model, res) for model in models for res in RESOLUTIONS)
    for idx, (model, res) in enumerate(pairs, start=1):
        label = str(res["label"])
        tag = f"{model.name}/{label}"
        log.info("(%d/%d) %s — starting", idx, total, tag)
        t0 = time.perf_counter()

        video_path = output_dir / f"{model.name}_{label}.mp4"
        cmd = _build_command(model, res, video_path)

        try:
            proc = subprocess.run(
                cmd,
                capture_output=True,
                text=True,
                timeout=timeout_s,
                env=os.environ.copy(),
                cwd=str(repo_root),
            )
        except subprocess.TimeoutExpired:
            # One hung run must not discard the results gathered so far.
            log.error("%s TIMED OUT after %ds", tag, timeout_s)
            results.append(TimingResult(model.name, label, -1.0))
            continue

        elapsed = time.perf_counter() - t0
        full = proc.stdout + proc.stderr
        # Echo only the tail; the timing table is printed at the end.
        print(full[-2000:])

        if proc.returncode != 0:
            log.error("%s FAILED (%.0fs)", tag, elapsed)
            results.append(TimingResult(model.name, label, -1.0))
            continue

        components = _parse_profiling(full)
        # Both spellings are removed from the breakdown; "E2E execute" wins
        # when both are present.
        fallback = components.pop("E2E", -1.0)
        e2e = components.pop("E2E execute", fallback)
        log.info("%s — E2E %.1fs (total %.0fs)", tag, e2e, elapsed)
        results.append(TimingResult(model.name, label, e2e, components))
    return results
202+
203+
204+
def _gpu_name() -> str:
205+
try:
206+
return (
207+
subprocess.check_output(
208+
["nvidia-smi", "--query-gpu=name", "--format=csv,noheader"],
209+
text=True,
210+
)
211+
.strip()
212+
.splitlines()[0]
213+
)
214+
except Exception:
215+
return "unknown GPU"
216+
217+
218+
def print_summary(all_results: list[TimingResult]) -> None:
    """Print the end-to-end table and the per-component breakdown.

    The first table has one row per (model, resolution) plus an ``avg`` row
    per model computed over the runs with a positive E2E time; other runs
    are shown as ``FAIL``. The second section lists each run's component
    timings, followed by the per-model average of every component.

    Args:
        all_results: Results as returned by ``run``. When the same
            (model, label) pair occurs more than once, the last one wins.
    """
    # NOTE(review): block nesting was reconstructed; the blank-line print()
    # calls are placed after each model in the E2E table and after each
    # printed component group. Confirm against the original layout.
    gpu = _gpu_name()
    # De-duplicate while keeping first-seen order.
    model_names = list(dict.fromkeys(r.model for r in all_results))
    labels = [str(res["label"]) for res in RESOLUTIONS]
    by_key: dict[tuple[str, str], TimingResult] = {
        (r.model, r.label): r for r in all_results
    }

    print(f"\n{'=' * 60}")
    print(f" Wan Full Metric — {gpu}, {NUM_STEPS} steps")
    print(f"{'=' * 60}\n")

    hdr = f"{'Model':<22} {'Resolution':<12} {'E2E (s)':>10}"
    print(hdr)
    print("-" * len(hdr))

    for model_name in model_names:
        e2e_vals: list[float] = []
        for label in labels:
            result = by_key.get((model_name, label))
            e2e = result.e2e_seconds if result else -1
            if e2e > 0:
                e2e_str = f"{e2e:>10.1f}"
                e2e_vals.append(e2e)
            else:
                e2e_str = f"{'FAIL':>10}"
            print(f"{model_name:<22} {label:<12} {e2e_str}")
        if e2e_vals:
            avg = sum(e2e_vals) / len(e2e_vals)
            print(f"{model_name:<22} {'avg':<12} {avg:>10.1f}")
        print()

    print(f"{'=' * 60}")
    print(" Component Breakdown (seconds)")
    print(f"{'=' * 60}\n")

    for model_name in model_names:
        # Print each resolution's breakdown and collect the values for the
        # per-model average in the same pass.
        comp_totals: dict[str, list[float]] = {}
        for label in labels:
            result = by_key.get((model_name, label))
            if not result or not result.components:
                continue
            print(f" {model_name} / {label}:")
            for comp, secs in sorted(result.components.items()):
                print(f" {comp:<30} {secs:>10.3f}s")
                comp_totals.setdefault(comp, []).append(secs)
            print()

        if comp_totals:
            print(f" {model_name} / avg:")
            for comp in sorted(comp_totals):
                vals = comp_totals[comp]
                print(f" {comp:<30} {sum(vals) / len(vals):>10.3f}s")
            print()
280+
281+
282+
def main() -> None:
    """Parse the command line, run the benchmark sweep and print the report.

    ``--model`` accepts zero or more keys of ``MODELS``; omitting the flag
    (or passing it without values) benchmarks every model. Unknown names
    are rejected by argparse with a usage error. ``--output-dir`` is created
    if it does not exist.
    """
    parser = argparse.ArgumentParser(
        description="Wan video generation full metric (MAX)"
    )
    parser.add_argument(
        "--model",
        nargs="*",
        default=None,
        # Validate here instead of failing later with a bare KeyError.
        choices=list(MODELS),
        metavar="MODEL",
        help=f"Model(s) to benchmark. Choices: {', '.join(MODELS)}. "
        "Default: all.",
    )
    parser.add_argument(
        "--output-dir",
        default="/tmp/wan_full_metric",
        help="Directory that receives the generated videos.",
    )
    args = parser.parse_args()

    # An absent or empty --model selects every registered model.
    selected = [MODELS[name] for name in (args.model or MODELS)]
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    all_results = run(selected, output_dir)
    print_summary(all_results)
305+
306+
307+
# Script entry point: run the benchmark only when executed directly.
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)