Skip to content

Commit 7026d20

Browse files
committed
working code
1 parent d799d57 commit 7026d20

6 files changed

Lines changed: 559 additions & 10 deletions

File tree

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,3 +180,6 @@ wandb
180180
# Gemini CLI
181181
.gemini/
182182
gha-creds-*.json
183+
184+
# JAX cache
185+
.jax_cache/

benchmark_attention.sh

Lines changed: 209 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,209 @@
1+
#!/bin/bash
2+
# =============================================================================
3+
# Comprehensive Attention Benchmark: flash vs ring vs ulysses
4+
# Runs across i2v/t2v pipelines, multiple batch sizes.
5+
# Computes quality metrics (PSNR, SSIM) using flash as reference.
6+
# =============================================================================
7+
set -euo pipefail
8+
9+
# === Environment ===
10+
source /data/maxdiffusion-work/maxdiffusion-venv/bin/activate
11+
export PYTHONPATH=/data/maxdiffusion-work/maxdiffusion/src
12+
export HF_HOME=/data/maxdiffusion-work/hf-home
13+
export JAX_CACHE_DIR="${JAX_CACHE_DIR:-/data/maxdiffusion-work/jax-cache}"
14+
15+
REPO_DIR="/data/maxdiffusion-work/maxdiffusion"
16+
BENCH_DIR="/data/maxdiffusion-work/bench_results_$(date +%Y%m%d_%H%M%S)"
17+
LOG_DIR="${BENCH_DIR}/logs"
18+
VIDEO_DIR="${BENCH_DIR}/videos"
19+
mkdir -p "$LOG_DIR" "$VIDEO_DIR"
20+
21+
GEN_SCRIPT="${REPO_DIR}/src/maxdiffusion/generate_wan.py"
22+
T2V_CONFIG="${REPO_DIR}/src/maxdiffusion/configs/base_wan_14b.yml"
23+
I2V_CONFIG="${REPO_DIR}/src/maxdiffusion/configs/base_wan_i2v_14b.yml"
24+
25+
# === Benchmark Parameters ===
26+
ATTENTION_MODES=("flash" "ring" "ulysses")
27+
# per_device_batch_size -> global batch size on 8 devices:
28+
# 0.125->1, 0.25->2, 0.5->4, 1->8, 2->16
29+
BATCH_SIZES=("0.125" "0.25" "0.5" "1" "2")
30+
31+
# Smaller block sizes that fit in VMEM at cp=8
32+
FLASH_BLOCK_SIZES='{"block_q":2048,"block_kv_compute":1024,"block_kv":4096,"block_q_dkv":2048,"block_kv_dkv":4096,"block_kv_dkv_compute":1024,"use_fused_bwd_kernel":true}'
33+
34+
RESULTS_FILE="${BENCH_DIR}/timing_results.csv"
35+
echo "pipeline,attention,batch_size,compile_time_s,inference_time_s" > "$RESULTS_FILE"
36+
37+
SUMMARY_FILE="${BENCH_DIR}/summary.txt"
38+
39+
echo "============================================================"
40+
echo " Attention Benchmark Suite"
41+
echo " Started: $(date)"
42+
echo " Output: ${BENCH_DIR}"
43+
echo "============================================================"
44+
45+
# === Helper: run one benchmark ===
46+
run_benchmark() {
47+
local pipeline=$1 attention=$2 batch_size=$3
48+
local config_file run_name log_file
49+
50+
if [[ "$pipeline" == "t2v" ]]; then
51+
config_file="$T2V_CONFIG"
52+
else
53+
config_file="$I2V_CONFIG"
54+
fi
55+
56+
run_name="${pipeline}_${attention}_bs${batch_size}"
57+
log_file="${LOG_DIR}/${run_name}.log"
58+
59+
echo ""
60+
echo "=========================================="
61+
echo " RUN: ${run_name}"
62+
echo " Config: ${config_file}"
63+
echo " Log: ${log_file}"
64+
echo "=========================================="
65+
66+
# Run generation; videos are saved in cwd
67+
cd "$VIDEO_DIR"
68+
python "$GEN_SCRIPT" "$config_file" \
69+
attention="$attention" \
70+
num_inference_steps=50 \
71+
num_frames=81 \
72+
ici_data_parallelism=1 \
73+
ici_context_parallelism=8 \
74+
ici_tensor_parallelism=1 \
75+
allow_split_physical_axes=True \
76+
per_device_batch_size="$batch_size" \
77+
seed=0 \
78+
run_name="$run_name" \
79+
jax_cache_dir="$JAX_CACHE_DIR" \
80+
flash_block_sizes="$FLASH_BLOCK_SIZES" \
81+
2>&1 | tee "$log_file"
82+
83+
# Rename output videos to include run info
84+
for f in wan_output_*.mp4; do
85+
if [[ -e "$f" ]]; then
86+
mv "$f" "${run_name}_${f}"
87+
fi
88+
done
89+
# Also catch prefixed outputs
90+
for f in *wan_output_*.mp4; do
91+
if [[ -e "$f" && ! "$f" =~ ^(i2v|t2v)_ ]]; then
92+
mv "$f" "${run_name}_${f}" 2>/dev/null || true
93+
fi
94+
done
95+
96+
cd "$REPO_DIR"
97+
98+
# Extract timing from log
99+
local compile_time inference_time
100+
compile_time=$(grep -oP 'compile_time:\s*\K[0-9.]+' "$log_file" 2>/dev/null | tail -1 || echo "N/A")
101+
inference_time=$(grep -oP 'generation_time:\s*\K[0-9.]+' "$log_file" 2>/dev/null | tail -1 || echo "N/A")
102+
103+
echo "${pipeline},${attention},${batch_size},${compile_time},${inference_time}" >> "$RESULTS_FILE"
104+
echo " >> compile=${compile_time}s inference=${inference_time}s"
105+
}
106+
107+
# =============================================================================
108+
# Phase 1: Run all benchmarks
109+
# =============================================================================
110+
echo ""
111+
echo "============================================================"
112+
echo " Phase 1: Running inference benchmarks"
113+
echo "============================================================"
114+
115+
for batch_size in "${BATCH_SIZES[@]}"; do
116+
echo ""
117+
echo "--- Batch size: ${batch_size} ---"
118+
119+
# I2V benchmarks
120+
for attention in "${ATTENTION_MODES[@]}"; do
121+
run_benchmark "i2v" "$attention" "$batch_size" || {
122+
echo "FAILED: i2v $attention bs=$batch_size"
123+
echo "i2v,${attention},${batch_size},FAILED,FAILED" >> "$RESULTS_FILE"
124+
}
125+
done
126+
127+
# T2V benchmarks
128+
for attention in "${ATTENTION_MODES[@]}"; do
129+
run_benchmark "t2v" "$attention" "$batch_size" || {
130+
echo "FAILED: t2v $attention bs=$batch_size"
131+
echo "t2v,${attention},${batch_size},FAILED,FAILED" >> "$RESULTS_FILE"
132+
}
133+
done
134+
done
135+
136+
# =============================================================================
137+
# Phase 2: Quality comparison (PSNR, SSIM)
138+
# =============================================================================
139+
echo ""
140+
echo "============================================================"
141+
echo " Phase 2: Quality Metrics (PSNR, SSIM)"
142+
echo "============================================================"
143+
144+
python "$REPO_DIR/benchmark_quality.py" "$VIDEO_DIR" "$RESULTS_FILE" 2>&1 | tee "${LOG_DIR}/quality.log"
145+
146+
# =============================================================================
147+
# Phase 3: Final Summary
148+
# =============================================================================
149+
{
150+
echo "============================================================"
151+
echo " ATTENTION BENCHMARK RESULTS"
152+
echo " Date: $(date)"
153+
echo " Branch: $(git -C "$REPO_DIR" branch --show-current)"
154+
echo " Devices: $(python -c 'import jax; print(f"{len(jax.devices())} x {jax.devices()[0].device_kind}")')"
155+
echo "============================================================"
156+
echo ""
157+
echo "--- Timing Results ---"
158+
column -t -s',' "$RESULTS_FILE"
159+
echo ""
160+
161+
QUALITY_CSV="${BENCH_DIR}/quality_results.csv"
162+
if [[ -f "$QUALITY_CSV" ]]; then
163+
echo "--- Quality Results (flash = reference) ---"
164+
column -t -s',' "$QUALITY_CSV"
165+
echo ""
166+
fi
167+
168+
echo "--- Speedup vs Flash ---"
169+
python3 -c "
170+
import csv, sys
171+
rows = []
172+
with open('$RESULTS_FILE') as f:
173+
reader = csv.DictReader(f)
174+
for r in reader:
175+
rows.append(r)
176+
177+
# Group by (pipeline, batch_size)
178+
groups = {}
179+
for r in rows:
180+
key = (r['pipeline'], r['batch_size'])
181+
groups.setdefault(key, {})[r['attention']] = r
182+
183+
for key in sorted(groups):
184+
pipeline, bs = key
185+
flash = groups[key].get('flash', {})
186+
flash_inf = flash.get('inference_time_s', 'N/A')
187+
if flash_inf in ('N/A', 'FAILED'):
188+
continue
189+
flash_inf = float(flash_inf)
190+
print(f' {pipeline} bs={bs} (flash baseline: {flash_inf:.1f}s)')
191+
for attn in ['ring', 'ulysses']:
192+
data = groups[key].get(attn, {})
193+
inf = data.get('inference_time_s', 'N/A')
194+
if inf in ('N/A', 'FAILED'):
195+
print(f' {attn}: N/A')
196+
continue
197+
inf = float(inf)
198+
speedup = ((flash_inf - inf) / flash_inf) * 100
199+
sign = '+' if speedup >= 0 else ''
200+
print(f' {attn}: {inf:.1f}s ({sign}{speedup:.1f}%)')
201+
" 2>/dev/null || echo " (could not compute speedups)"
202+
echo ""
203+
echo "Results dir: ${BENCH_DIR}"
204+
echo "============================================================"
205+
} | tee "$SUMMARY_FILE"
206+
207+
echo ""
208+
echo "Benchmark complete at $(date)"
209+
echo "Full summary: ${SUMMARY_FILE}"

benchmark_quality.py

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
"""Compare video quality across attention kernels.
2+
3+
Reads mp4 files, computes per-frame PSNR and SSIM, and appends results to a CSV.
4+
Usage: python benchmark_quality.py <video_dir> <results_csv>
5+
"""
6+
7+
import csv
8+
import glob
9+
import os
10+
import re
11+
import sys
12+
13+
import cv2
14+
import numpy as np
15+
from skimage.metrics import structural_similarity as ssim
16+
17+
18+
def read_video_frames(path, max_frames=None):
19+
"""Read video frames as numpy arrays."""
20+
cap = cv2.VideoCapture(path)
21+
frames = []
22+
while cap.isOpened():
23+
ret, frame = cap.read()
24+
if not ret:
25+
break
26+
frames.append(frame)
27+
if max_frames and len(frames) >= max_frames:
28+
break
29+
cap.release()
30+
return frames
31+
32+
33+
def compute_psnr(ref_frame, test_frame):
34+
mse = np.mean((ref_frame.astype(np.float64) - test_frame.astype(np.float64)) ** 2)
35+
if mse == 0:
36+
return float("inf")
37+
return 10 * np.log10(255.0**2 / mse)
38+
39+
40+
def compute_metrics(ref_path, test_path):
41+
"""Compute average PSNR and SSIM between two videos."""
42+
ref_frames = read_video_frames(ref_path)
43+
test_frames = read_video_frames(test_path)
44+
45+
if not ref_frames or not test_frames:
46+
return None, None
47+
48+
n = min(len(ref_frames), len(test_frames))
49+
psnr_vals = []
50+
ssim_vals = []
51+
52+
for i in range(n):
53+
ref = ref_frames[i]
54+
tst = test_frames[i]
55+
# Resize if dimensions differ
56+
if ref.shape != tst.shape:
57+
tst = cv2.resize(tst, (ref.shape[1], ref.shape[0]))
58+
59+
psnr_vals.append(compute_psnr(ref, tst))
60+
# SSIM on grayscale
61+
ref_gray = cv2.cvtColor(ref, cv2.COLOR_BGR2GRAY)
62+
tst_gray = cv2.cvtColor(tst, cv2.COLOR_BGR2GRAY)
63+
ssim_vals.append(ssim(ref_gray, tst_gray))
64+
65+
return np.mean(psnr_vals), np.mean(ssim_vals)
66+
67+
68+
def find_video_pairs(video_dir):
69+
"""Find flash (reference) vs ring/ulysses video pairs.
70+
71+
Expects filenames like: {pipeline}_{attention}_bs{batch}_wan_output_0_0.mp4
72+
"""
73+
videos = glob.glob(os.path.join(video_dir, "*.mp4"))
74+
# Group by (pipeline, batch_size)
75+
groups = {}
76+
for v in videos:
77+
base = os.path.basename(v)
78+
# Match pattern: {pipeline}_{attention}_bs{batch}_wan_output_{seed}_{idx}.mp4
79+
m = re.match(r"^(i2v|t2v)_(flash|ring|ulysses)_(bs[\d.]+)_wan_output_(\d+)_(\d+)\.mp4$", base)
80+
if not m:
81+
continue
82+
pipeline, attention, batch, seed, idx = m.groups()
83+
key = (pipeline, batch, seed, idx)
84+
groups.setdefault(key, {})[attention] = v
85+
86+
pairs = []
87+
for key, attn_map in sorted(groups.items()):
88+
if "flash" not in attn_map:
89+
continue
90+
ref = attn_map["flash"]
91+
for attn in ["ring", "ulysses"]:
92+
if attn in attn_map:
93+
pairs.append((key, attn, ref, attn_map[attn]))
94+
return pairs
95+
96+
97+
def main():
98+
video_dir = sys.argv[1]
99+
results_csv = sys.argv[2]
100+
101+
pairs = find_video_pairs(video_dir)
102+
if not pairs:
103+
print("No video pairs found for quality comparison.")
104+
return
105+
106+
quality_csv = os.path.join(os.path.dirname(results_csv), "quality_results.csv")
107+
with open(quality_csv, "w", newline="") as f:
108+
writer = csv.writer(f)
109+
writer.writerow(["pipeline", "test_attention", "batch_size", "avg_psnr_db", "avg_ssim"])
110+
for key, attn, ref_path, test_path in pairs:
111+
pipeline, batch, seed, idx = key
112+
print(f"Comparing: flash vs {attn} | {pipeline} {batch} (video {idx})")
113+
avg_psnr, avg_ssim = compute_metrics(ref_path, test_path)
114+
if avg_psnr is not None:
115+
print(f" PSNR: {avg_psnr:.2f} dB | SSIM: {avg_ssim:.4f}")
116+
writer.writerow([pipeline, attn, batch, f"{avg_psnr:.2f}", f"{avg_ssim:.4f}"])
117+
else:
118+
print(" Could not compute metrics (missing/empty video)")
119+
writer.writerow([pipeline, attn, batch, "N/A", "N/A"])
120+
121+
print(f"\nQuality results saved to: {quality_csv}")
122+
123+
# Print summary table
124+
print("\n" + "=" * 60)
125+
print(" QUALITY SUMMARY (flash = reference)")
126+
print("=" * 60)
127+
with open(quality_csv) as f:
128+
for line in f:
129+
print(f" {line.strip()}")
130+
print("=" * 60)
131+
132+
133+
if __name__ == "__main__":
134+
main()

src/maxdiffusion/common_types.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,3 +84,13 @@
8484
[CROSS_ATTN_Q_LENGTH, CONTEXT],
8585
[CROSS_ATTN_KV_LENGTH, None],
8686
]
87+
88+
### Common axis rules for ulysses attention ###
89+
ULYSSES_ATTENTION_AXIS_RULES = [
90+
[SELF_ATTN_HEAD, None],
91+
[SELF_ATTN_Q_LENGTH, CONTEXT],
92+
[SELF_ATTN_KV_LENGTH, CONTEXT],
93+
[CROSS_ATTN_HEAD, None],
94+
[CROSS_ATTN_Q_LENGTH, CONTEXT],
95+
[CROSS_ATTN_KV_LENGTH, CONTEXT],
96+
]

0 commit comments

Comments
 (0)