Skip to content

Commit 04574b4

Browse files
sagarchapara and claude
committed
feat: add attention benchmark matrix script and ops guide
bench_attn.sh: 6 attention strategies × 4 batch sizes (bs 1/2/4/8) on v7x-16, profiler only for bs=2. Updated XLA flags: LHS rerun=5, max_concurrent_async_collective_permutes=16, ici_ag_pipelining=true. docs/tpu_wan_bench_guide.md: how to launch benchmarks, view xprof traces, sync code to worker1, clear TPU locks after crashes, and parse results. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 6604952 commit 04574b4

3 files changed

Lines changed: 352 additions & 2 deletions

File tree

bench_attn.sh

Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
#!/usr/bin/env bash
#
# bench_attn.sh — attention benchmark matrix for WAN 2.2 on a TPU v7x-16
# (2 hosts × 8 chips). This preamble defines paths, worker-1 SSH settings,
# libtpu/XLA tuning flags, and the cache environment shared by every run.
set -euo pipefail

# ── config ────────────────────────────────────────────────────────────────────
REPO_DIR=/dev/shm/maxdiffusion
VENV=$REPO_DIR/.venv-tpu
CONFIG=$REPO_DIR/src/maxdiffusion/configs/base_wan_27b.yml
RESULTS_ROOT=$REPO_DIR/bench_results
OUTPUT_ROOT=$REPO_DIR/bench_outputs

# Worker 1 (second TPU host) — the coordinator SSHes the same job there.
WORKER1_USER=sa_112155357684894056033
WORKER1_IP=10.154.0.59
WORKER1_HOST_ALIAS=tpu.2884514015978940116-1-jgVFrB
SSH_KEY=/home/sagarchapara_google_com/.ssh/google_compute_engine
SSH_KNOWN_HOSTS=/home/sagarchapara_google_com/.ssh/google_compute_known_hosts
# Kept as a single string on purpose: every consumer expands $SSH_OPTS
# unquoted so the shell word-splits it back into separate ssh options.
SSH_OPTS="-T -i $SSH_KEY -o CheckHostIP=no -o HashKnownHosts=no -o HostKeyAlias=$WORKER1_HOST_ALIAS -o IdentitiesOnly=yes -o StrictHostKeyChecking=no -o UserKnownHostsFile=$SSH_KNOWN_HOSTS"

# XLA / libtpu performance flags, one per element for readability;
# joined with single spaces into LIBTPU_INIT_ARGS below.
xla_flags=(
  --xla_tpu_dvfs_p_state=7
  --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true
  --xla_tpu_megacore_fusion_allow_ags=false
  --xla_enable_async_collective_permute=true
  --xla_tpu_enable_ag_backward_pipelining=true
  --xla_tpu_enable_data_parallel_all_reduce_opt=true
  --xla_tpu_data_parallel_opt_different_sized_ops=true
  --xla_tpu_enable_async_collective_fusion=true
  --xla_tpu_enable_async_collective_fusion_multiple_steps=true
  --xla_tpu_overlap_compute_collective_tc=true
  --xla_enable_async_all_gather=true
  --xla_tpu_scoped_vmem_limit_kib=65536
  --xla_tpu_enable_async_all_to_all=true
  --xla_tpu_enable_latency_hiding_scheduler=true
  --xla_tpu_enable_all_experimental_scheduler_features=true
  --xla_tpu_enable_scheduler_memory_pressure_tracking=true
  --xla_tpu_host_transfer_overlap_limit=24
  --xla_tpu_aggressive_opt_barrier_removal=ENABLED
  --xla_lhs_prioritize_async_depth_over_stall=ENABLED
  --xla_should_allow_loop_variant_parameter_in_chain=ENABLED
  --xla_should_add_loop_invariant_op_in_chain=ENABLED
  --xla_max_concurrent_host_send_recv=100
  --xla_tpu_scheduler_percent_shared_memory_limit=100
  --xla_latency_hiding_scheduler_rerun=5
  --xla_tpu_use_minor_sharding_for_major_trivial_input=true
  --xla_tpu_relayout_group_size_threshold_for_reduce_scatter=1
  --xla_tpu_spmd_rng_bit_generator_unsafe=true
  --xla_tpu_assign_all_reduce_scatter_layout=true
  --xla_max_concurrent_async_collective_permutes=16
  --xla_tpu_enable_ici_ag_pipelining=true
)
# "${arr[*]}" joins with the first character of IFS (a space here).
export LIBTPU_INIT_ARGS="${xla_flags[*]}"

source "$VENV/bin/activate"
export PYTHONPATH=$REPO_DIR/src:${PYTHONPATH:-}

# HF / JAX caches live on tmpfs; offline mode so runs never hit the network.
export HF_HOME=/dev/shm/maxdiffusion_cache/huggingface
export HF_HUB_CACHE=/dev/shm/maxdiffusion_cache/huggingface/hub
export HF_HUB_ENABLE_HF_TRANSFER=1
export HF_HUB_OFFLINE=1
export JAX_COMPILATION_CACHE_DIR=/dev/shm/maxdiffusion_cache/jax
export XLA_CACHE_DIR=/dev/shm/maxdiffusion_cache/xla
export TMPDIR=/dev/shm/maxdiffusion_cache/tmp
62+
# ── helper ────────────────────────────────────────────────────────────────────
# run_case <run_name> <attention> <ici_dp> <ici_a> <ici_b> <per_device_batch_size>
#   pure modes (flash / tokamax_ring / ulysses): ici_a = context axis, ici_b unused
#   hybrid modes (ulysses_ring*):                ici_a = ring axis, ici_b = ulysses axis
# Launches generate_wan.py on this host while SSH-ing the identical command to
# worker 1 in the background, then waits for both and reports both exit codes.
# Failures do not abort the matrix — the caller sees them in the "done" line.
run_case() {
  local run_name=$1 attention=$2 ici_dp=$3 ici_a=$4 ici_b=$5 pdb=$6

  local results_dir=$RESULTS_ROOT/$run_name
  rm -rf "$results_dir" "$OUTPUT_ROOT/$run_name"
  mkdir -p "$results_dir" "$TMPDIR"
  rm -f /tmp/libtpu_lockfile  # stale lock left by a previously crashed run

  # Map the generic <a>/<b> columns onto the concrete ICI mesh axes.
  local ici_cp=1 ici_tp=1 ici_ring=1 ici_ulysses=1
  case "$attention" in
    ulysses_ring|ulysses_tokamax_ring)
      ici_ring=$ici_a
      ici_ulysses=$ici_b
      echo "[$(date -u +%T)] ── Starting $run_name (attention=$attention dp=$ici_dp ring=$ici_ring ulysses=$ici_ulysses pdb=$pdb) ──"
      ;;
    *)
      ici_cp=$ici_a
      echo "[$(date -u +%T)] ── Starting $run_name (attention=$attention dp=$ici_dp cp=$ici_cp pdb=$pdb) ──"
      ;;
  esac

  # Only capture profiler for bs=2 (per_device_batch_size=0.125).
  local profiler_args="enable_profiler=False"
  if [[ "$pdb" == "0.125" ]]; then
    profiler_args="enable_profiler=True skip_first_n_steps_for_profiler=5 profiler_steps=10"
  fi

  # Per-run config overrides for generate_wan.py. Each element is one
  # key=value word; $profiler_args is deliberately unquoted so it splits.
  local overrides=(
    run_name=$run_name
    attention=$attention
    ici_data_parallelism=$ici_dp
    ici_fsdp_parallelism=1
    ici_context_parallelism=$ici_cp
    ici_tensor_parallelism=$ici_tp
    ici_ring_parallelism=$ici_ring
    ici_ulysses_parallelism=$ici_ulysses
    dcn_data_parallelism=1
    dcn_fsdp_parallelism=1
    dcn_context_parallelism=1
    dcn_tensor_parallelism=1
    dcn_ring_parallelism=1
    dcn_ulysses_parallelism=1
    height=720 width=1280 num_frames=81 num_inference_steps=40
    per_device_batch_size=$pdb
    output_dir=$OUTPUT_ROOT
    scan_layers=True
    write_metrics=False
    write_timing_metrics=False
    $profiler_args
  )
  local common_args="${overrides[*]}"

  # Build the worker-1 script with all values expanded locally, then
  # shell-quote the whole thing (%q) so it survives "bash -lc".
  local remote_cmd
  remote_cmd="$(printf '%q' "set -euo pipefail
export LIBTPU_INIT_ARGS='$LIBTPU_INIT_ARGS'
source $VENV/bin/activate
export PYTHONPATH=$REPO_DIR/src:\${PYTHONPATH:-}
export HF_HOME=$HF_HOME
export HF_HUB_CACHE=$HF_HUB_CACHE
export HF_HUB_ENABLE_HF_TRANSFER=1
export HF_HUB_OFFLINE=1
export JAX_COMPILATION_CACHE_DIR=$JAX_COMPILATION_CACHE_DIR
export XLA_CACHE_DIR=$XLA_CACHE_DIR
export TMPDIR=$TMPDIR
rm -f /tmp/libtpu_lockfile
mkdir -p $TMPDIR $results_dir $OUTPUT_ROOT
cd $results_dir
python -u $REPO_DIR/src/maxdiffusion/generate_wan.py $CONFIG $common_args 2>&1 | tee $results_dir/worker1.log")"

  # Both hosts must start near-simultaneously for JAX distributed init:
  # worker 1 goes to the background, the local job runs in the foreground.
  /usr/bin/ssh $SSH_OPTS "$WORKER1_USER@$WORKER1_IP" "bash -lc $remote_cmd" &
  local remote_pid=$!

  cd "$results_dir"
  set +e  # collect both exit codes instead of aborting the matrix
  python -u "$REPO_DIR/src/maxdiffusion/generate_wan.py" \
    "$CONFIG" \
    "${overrides[@]}" \
    2>&1 | tee "$results_dir/worker0.log"
  local local_status=$?
  wait "$remote_pid"
  local remote_status=$?
  set -e

  echo "[$(date -u +%T)] $run_name done — local=$local_status remote=$remote_status"
  echo "────────────────────────────────────────────────────────────────────────"
}
152+
153+
# ── run matrix ────────────────────────────────────────────────────────────────
# Columns: run_name attention dp a b per_device_bs
#
# All runs use 16 total devices (2 hosts × 8). Most cases run dp=2 with
# cp=8 (or ring×ulysses=8); the last hybrid case runs dp=1 with
# ring×ulysses=16 (full 16-chip sequence sharding).
#
# Batch sizes (per_device_batch_size → global batch over 16 devices):
#   0.0625→1, 0.125→2, 0.25→4, 0.5→8
# Encoded as pdb:bs pairs so no external interpreter is needed for the math.

for spec in 0.0625:1 0.125:2 0.25:4 0.5:8; do
  pdb=${spec%:*}  # per-device batch size passed to generate_wan.py
  bs=${spec#*:}   # global batch size, used only in run names
  run_case "flash_dp2_cp8_bs${bs}"         flash        2 8 1 "$pdb"
  run_case "tokamax_ring_dp2_cp8_bs${bs}"  tokamax_ring 2 8 1 "$pdb"
  run_case "ulysses_dp2_cp8_bs${bs}"       ulysses      2 8 1 "$pdb"
  run_case "ulysses_ring_dp2_r2u4_bs${bs}" ulysses_ring 2 2 4 "$pdb"  # ring=2, ulysses=4, dp=2
  run_case "ulysses_ring_dp2_r4u2_bs${bs}" ulysses_ring 2 4 2 "$pdb"  # ring=4, ulysses=2, dp=2
  run_case "ulysses_ring_dp1_r4u4_bs${bs}" ulysses_ring 1 4 4 "$pdb"  # ring=4, ulysses=4, dp=1
done

echo "[$(date -u +%T)] All benchmark runs complete."

bench_remaining.sh

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ export LIBTPU_INIT_ARGS=\
3131
'--xla_enable_async_all_gather=true '\
3232
'--xla_tpu_scoped_vmem_limit_kib=65536 '\
3333
'--xla_tpu_enable_async_all_to_all=true '\
34+
'--xla_tpu_enable_latency_hiding_scheduler=true '\
3435
'--xla_tpu_enable_all_experimental_scheduler_features=true '\
3536
'--xla_tpu_enable_scheduler_memory_pressure_tracking=true '\
3637
'--xla_tpu_host_transfer_overlap_limit=24 '\
@@ -40,11 +41,12 @@ export LIBTPU_INIT_ARGS=\
4041
'--xla_should_add_loop_invariant_op_in_chain=ENABLED '\
4142
'--xla_max_concurrent_host_send_recv=100 '\
4243
'--xla_tpu_scheduler_percent_shared_memory_limit=100 '\
43-
'--xla_latency_hiding_scheduler_rerun=2 '\
44+
'--xla_latency_hiding_scheduler_rerun=5 '\
4445
'--xla_tpu_use_minor_sharding_for_major_trivial_input=true '\
4546
'--xla_tpu_relayout_group_size_threshold_for_reduce_scatter=1 '\
4647
'--xla_tpu_spmd_rng_bit_generator_unsafe=true '\
47-
'--xla_tpu_assign_all_reduce_scatter_layout=true'
48+
'--xla_tpu_assign_all_reduce_scatter_layout=true '\
49+
'--xla_max_concurrent_async_collective_permutes=16'
4850

4951
source "$VENV/bin/activate"
5052
export PYTHONPATH=$REPO_DIR/src:${PYTHONPATH:-}

docs/tpu_wan_bench_guide.md

Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
# WAN 2.2 Attention Benchmarking Guide (TPU v7x-16)
2+
3+
## Setup
4+
5+
Two-host TPU v7x slice: 2 × 8 chips = 16 devices total.
6+
7+
- **Host 0** (local): coordinator, runs `generate_wan.py` + SSH to host 1
8+
- **Host 1** (worker): `sa_112155357684894056033@10.154.0.59`
9+
10+
## Quick Start
11+
12+
### Run the full attention benchmark
13+
14+
```bash
15+
nohup bash /dev/shm/maxdiffusion/bench_attn.sh \
16+
> /dev/shm/maxdiffusion/bench_results/bench_attn.log 2>&1 &
17+
18+
# Monitor
19+
tail -f /dev/shm/maxdiffusion/bench_results/bench_attn.log
20+
21+
# Check results as they come in
22+
grep -E "compile_time|Inference:|done —" /dev/shm/maxdiffusion/bench_results/bench_attn.log
23+
```
24+
25+
### Run a single attention mode
26+
27+
Use the one-shot helper scripts in `/tmp/`:
28+
29+
```bash
30+
# Ulysses dp=2 cp=8
31+
nohup bash /tmp/run_ulysses_v3.sh > /dev/shm/maxdiffusion/bench_results/myrun.log 2>&1 &
32+
33+
# Monitor
34+
tail -f /dev/shm/maxdiffusion/bench_results/myrun.log
35+
```
36+
37+
Or call `generate_wan.py` directly on both hosts (see [Multihost section](#multihost-runs)).
38+
39+
---
40+
41+
## Benchmark Matrix (`bench_attn.sh`)
42+
43+
6 attention strategies × 4 batch sizes = 24 runs.
44+
Profiler captured only for bs=2 to avoid disk overhead.
45+
46+
| Attention | dp | cp/ring/ulysses | Notes |
47+
|-----------|-----|-----------------|-------|
48+
| `flash` | 2 | cp=8 | local flash per shard, no KV rotation |
49+
| `tokamax_ring` | 2 | cp=8 | Tokamax ring kernel, KV rotated across context axis |
50+
| `ulysses` | 2 | cp=8 | Ulysses all-to-all in BSHD layout |
51+
| `ulysses_ring` | 2 | ring=2, ulysses=4 | Hybrid 2D: ulysses intra-chip, ring cross-chip |
52+
| `ulysses_ring` | 2 | ring=4, ulysses=2 | Hybrid 2D: alternative split |
53+
| `ulysses_ring` | 1 | ring=4, ulysses=4 | Hybrid 2D: full 16-chip seq sharding |
54+
55+
Batch sizes: `per_device_batch_size` ∈ {0.0625, 0.125, 0.25, 0.5}
56+
→ total videos = per_device_batch_size × 16 devices ∈ {1, 2, 4, 8}
57+
58+
Results land in:
59+
```
60+
/dev/shm/maxdiffusion/bench_results/<run_name>/worker0.log
61+
/dev/shm/maxdiffusion/bench_results/<run_name>/worker1.log
62+
```
63+
64+
---
65+
66+
## View Profiler Traces (xprof)
67+
68+
```bash
69+
source /dev/shm/maxdiffusion/.venv-tpu/bin/activate
70+
71+
python -c "
72+
from xprof.server import main
73+
import sys
74+
sys.argv=['xprof',
75+
'--logdir', '/dev/shm/maxdiffusion/bench_outputs/<run_name>/<run_name>/tensorboard',
76+
'--port', '6006']
77+
main()
78+
"
79+
```
80+
81+
Then open **http://localhost:6006/**
82+
83+
Pick the **latest timestamp** under `plugins/profile/` — it's from the warm (second) inference round with XLA cache hot.
84+
85+
To switch runs, kill the server (`pkill -f "xprof"`) and relaunch with the new logdir.
86+
87+
---
88+
89+
## Multihost Runs
90+
91+
Both hosts must launch `generate_wan.py` simultaneously — JAX distributed requires all workers to connect within the timeout window (~5 min).
92+
93+
The benchmark scripts handle this automatically: the local host runs the job directly while SSH-ing the same command to worker 1 in the background, then `wait`s for both.
94+
95+
### Sync code changes to worker 1
96+
97+
```bash
98+
SSH_KEY=/home/sagarchapara_google_com/.ssh/google_compute_engine
99+
SSH_KNOWN_HOSTS=/home/sagarchapara_google_com/.ssh/google_compute_known_hosts
100+
WORKER1_HOST_ALIAS=tpu.2884514015978940116-1-jgVFrB
101+
SSH_OPTS="-T -i $SSH_KEY -o CheckHostIP=no -o HashKnownHosts=no \
102+
-o HostKeyAlias=$WORKER1_HOST_ALIAS -o IdentitiesOnly=yes \
103+
-o StrictHostKeyChecking=no -o UserKnownHostsFile=$SSH_KNOWN_HOSTS"
104+
105+
rsync -a --exclude='__pycache__' --exclude='*.pyc' \
106+
-e "/usr/bin/ssh $SSH_OPTS" \
107+
/dev/shm/maxdiffusion/src/ sa_112155357684894056033@10.154.0.59:/dev/shm/maxdiffusion/src/
108+
```
109+
110+
**Always rsync before launching a run after code changes.**
111+
112+
### Clear TPU locks after a crash
113+
114+
If a run crashes mid-flight, the TPU vfio devices may stay locked:
115+
116+
```bash
117+
# Local
118+
ps -ef | grep generate_wan | grep -v grep | awk '{print $2}' | xargs -r kill -9
119+
rm -f /tmp/libtpu_lockfile
120+
121+
# Remote
122+
/usr/bin/ssh $SSH_OPTS sa_112155357684894056033@10.154.0.59 \
123+
'ps -ef | grep generate_wan | grep -v grep | awk '"'"'{print $2}'"'"' | xargs -r kill -9; rm -f /tmp/libtpu_lockfile'
124+
```
125+
126+
Then wait ~5 seconds before relaunching.
127+
128+
---
129+
130+
## Key Config Parameters
131+
132+
Set in `src/maxdiffusion/configs/base_wan_27b.yml` or overridden on the command line:
133+
134+
| Parameter | Description |
135+
|-----------|-------------|
136+
| `attention` | `flash`, `tokamax_ring`, `ulysses`, `ulysses_ring` |
137+
| `ici_data_parallelism` | Data parallel replicas across the ICI (in-slice) mesh |
138+
| `ici_context_parallelism` | Sequence shards (flash/ring/ulysses) |
139+
| `ici_ring_parallelism` | Ring axis size (ulysses_ring only) |
140+
| `ici_ulysses_parallelism` | Ulysses axis size (ulysses_ring only) |
141+
| `per_device_batch_size` | Videos per device; total = × 16 |
142+
| `num_inference_steps` | Denoising steps (40 for full quality) |
143+
| `enable_profiler` | Capture xprof trace |
144+
| `skip_first_n_steps_for_profiler` | Warmup steps before profiling (5) |
145+
| `profiler_steps` | Steps to profile (10) |
146+
147+
**Parallelism rule**: product of all ICI axes must equal 16 (total chips in the v7x-16 slice):
148+
- `ici_dp × ici_fsdp × ici_cp × ici_tp × ici_ring × ici_ulysses = 16`
149+
150+
For `ulysses_ring`, set `ici_context_parallelism=1` and use `ici_ring` + `ici_ulysses` instead.
151+
152+
---
153+
154+
## XLA Flags
155+
156+
All performance-critical flags are set in `LIBTPU_INIT_ARGS` in the benchmark scripts. Notable ones:
157+
158+
| Flag | Value | Purpose |
159+
|------|-------|---------|
160+
| `xla_tpu_enable_latency_hiding_scheduler` | true | Overlap compute and collectives |
161+
| `xla_latency_hiding_scheduler_rerun` | 5 | LHS scheduling passes |
162+
| `xla_enable_async_collective_permute` | true | Async KV rotation for ring attention |
163+
| `xla_max_concurrent_async_collective_permutes` | 16 | Max in-flight ring permutes |
164+
| `xla_tpu_enable_async_all_to_all` | true | Async Ulysses all-to-all |
165+
| `xla_tpu_enable_ici_ag_pipelining` | true | Pipeline ICI all-gathers |
166+
| `xla_tpu_scoped_vmem_limit_kib` | 65536 | VMEM budget per op |
167+
168+
---
169+
170+
## Result Parsing
171+
172+
```bash
173+
# Summary table across all runs
174+
grep -E "compile_time|Inference:|generation_time per video" \
175+
/dev/shm/maxdiffusion/bench_results/*/worker0.log
176+
177+
# Just inference time
178+
grep "Inference:" /dev/shm/maxdiffusion/bench_results/*/worker0.log | sort -t: -k3 -n
179+
```

0 commit comments

Comments
 (0)