
Commit 6604952

sagarchapara and claude committed
feat: add ring/ulysses mesh axes and BSHD Ulysses attention for WAN 2.2

Adds dedicated `ring` and `ulysses` mesh axes so hybrid context parallelism does not interfere with weight TP sharding rules. Rewrites `_ulysses_attention` to keep tensors in BSHD layout through the all-to-all collective, avoiding a full-sequence BHSD transpose before the redistribution: ~34% faster than the original implementation and 8% faster than plain ring at dp=2 cp=8 on v7x-16.

- common_types: add RING/ULYSSES axis name constants and axis rules for the ulysses_ring hybrid (sequence sharded over [ring, ulysses])
- configs: add ici/dcn ring+ulysses parallelism params and update mesh_axes, logical_axis_rules, data_sharding, and flash block sizes for all WAN configs
- max_utils: extend create_device_mesh to append ring/ulysses axes when present
- attention_flax: BSHD-native _ulysses_attention; new _ulysses_ring_attention combining the Ulysses all-to-all (ulysses axis) with Tokamax ring attention (ring axis); routing and cross-attention fallback wired into _apply_attention
- pyconfig: prepend ULYSSES_RING_ATTENTION_AXIS_RULES for ulysses_ring modes
- bench_remaining.sh, docs: multihost benchmark harness and results

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
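For intuition, the BSHD-native exchange can be sketched with `shard_map` and `jax.lax.all_to_all`. This is a hypothetical reconstruction of the idea, not the commit's `_ulysses_attention`; the mesh shape, tensor sizes, and the `heads_to_sequence` name are assumptions.

```python
# Hypothetical sketch, not the commit's code: scatter heads / gather sequence
# over a dedicated "ulysses" axis while staying in BSHD layout, so no
# full-sequence BHSD transpose is needed around the collective.
# Assumes 8 devices arranged as a 2x4 (ring x ulysses) mesh.
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ("ring", "ulysses"))

@partial(
    shard_map,
    mesh=mesh,
    in_specs=P(None, ("ring", "ulysses"), None, None),  # [B, S, H, D], S sharded over both axes
    out_specs=P(None, "ring", "ulysses", None),         # S sharded over ring, H over ulysses
)
def heads_to_sequence(x):
    # Per shard: [B, S/(r*u), H, D] -> [B, S/r, H/u, D], still BSHD throughout.
    return jax.lax.all_to_all(x, "ulysses", split_axis=2, concat_axis=1, tiled=True)

x = jnp.zeros((1, 1024, 8, 64))  # [B, S, H, D]; H must be divisible by the ulysses size
y = heads_to_sequence(x)         # per device: ring-local sequence chunk for H/u heads
```

Ring attention over the `ring` axis would then consume `y` per head group, which matches the commit's description of combining the Ulysses all-to-all with Tokamax ring attention.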
1 parent cddfebb · commit 6604952

14 files changed

Lines changed: 815 additions & 63 deletions

README.md

Lines changed: 17 additions & 0 deletions
@@ -597,6 +597,23 @@ To generate images, run the following command:
Ulysses requires `ici_context_parallelism` greater than 1, and the number of attention heads must be divisible by the context shard count. `flash_block_sizes` tuning is optional and can still be used for hardware-specific tuning.

For TPU multihost 2D context parallelism, use `attention="ulysses_ring"`. This shards the self-attention sequence over `context` x `tensor`, runs the Ulysses all-to-all over the `tensor` mesh axis, and reuses Tokamax ring attention over the `context` mesh axis. The number of attention heads must be divisible by `ici_tensor_parallelism`; a typical multihost setup uses DCN context for the ring axis and ICI tensor for the Ulysses axis. On TPU7x, keep `dcn_tensor_parallelism=1` and set `ici_tensor_parallelism >= 2` so the dual chiplets exposed as two JAX devices are grouped by Ulysses rather than ring.

```bash
python src/maxdiffusion/generate_wan.py \
  src/maxdiffusion/configs/base_wan_i2v_27b.yml \
  attention="ulysses_ring" \
  dcn_context_parallelism=<num_slices> \
  ici_context_parallelism=1 \
  ici_tensor_parallelism=<ulysses_shards_per_slice> \
  ...
```
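For reference, the ring x Ulysses device layout these flags describe could be built with JAX's hybrid mesh helper. A minimal sketch, assuming 2 slices and 4 Ulysses shards per slice; the axis names mirror the `context`/`tensor` description above, and the concrete sizes are placeholders:

```python
# A minimal sketch (assumed sizes, not the library's internals): DCN slices on
# the ring/"context" axis, intra-slice ICI devices on the Ulysses/"tensor" axis.
# create_hybrid_device_mesh needs a real multi-slice allocation to succeed.
import jax
from jax.experimental import mesh_utils
from jax.sharding import Mesh

num_slices = 2          # dcn_context_parallelism (placeholder value)
ulysses_per_slice = 4   # ici_tensor_parallelism (placeholder value)

devices = mesh_utils.create_hybrid_device_mesh(
    mesh_shape=(1, ulysses_per_slice),  # per-slice (ICI): context=1, tensor=u
    dcn_mesh_shape=(num_slices, 1),     # across slices (DCN): context=n, tensor=1
)
mesh = Mesh(devices, ("context", "tensor"))
print(mesh.shape)  # expected: {'context': 2, 'tensor': 4}
```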
In our Wan2.2 I2V benchmarks at 40 inference steps, 81 frames, and `720x1280` resolution, Ulysses improved inference time by roughly 10% compared with flash attention, with about 20s lower latency on the v6e-8 and v7x-8 TPU setups.

### Caching Mechanisms

bench_remaining.sh

Lines changed: 157 additions & 0 deletions
@@ -0,0 +1,157 @@
#!/usr/bin/env bash
set -euo pipefail

# ── config ────────────────────────────────────────────────────────────────────
REPO_DIR=/dev/shm/maxdiffusion
VENV=$REPO_DIR/.venv-tpu
CONFIG=$REPO_DIR/src/maxdiffusion/configs/base_wan_27b.yml
RESULTS_ROOT=$REPO_DIR/bench_results
OUTPUT_ROOT=$REPO_DIR/bench_outputs

WORKER1_USER=sa_112155357684894056033
WORKER1_IP=10.154.0.59
WORKER1_HOST_ALIAS=tpu.2884514015978940116-1-jgVFrB
SSH_KEY=/home/sagarchapara_google_com/.ssh/google_compute_engine
SSH_KNOWN_HOSTS=/home/sagarchapara_google_com/.ssh/google_compute_known_hosts
SSH_OPTS="-T -i $SSH_KEY -o CheckHostIP=no -o HashKnownHosts=no \
  -o HostKeyAlias=$WORKER1_HOST_ALIAS -o IdentitiesOnly=yes \
  -o StrictHostKeyChecking=no -o UserKnownHostsFile=$SSH_KNOWN_HOSTS"

export LIBTPU_INIT_ARGS=\
'--xla_tpu_dvfs_p_state=7 '\
'--xla_tpu_enable_async_collective_fusion_fuse_all_gather=true '\
'--xla_tpu_megacore_fusion_allow_ags=false '\
'--xla_enable_async_collective_permute=true '\
'--xla_tpu_enable_ag_backward_pipelining=true '\
'--xla_tpu_enable_data_parallel_all_reduce_opt=true '\
'--xla_tpu_data_parallel_opt_different_sized_ops=true '\
'--xla_tpu_enable_async_collective_fusion=true '\
'--xla_tpu_enable_async_collective_fusion_multiple_steps=true '\
'--xla_tpu_overlap_compute_collective_tc=true '\
'--xla_enable_async_all_gather=true '\
'--xla_tpu_scoped_vmem_limit_kib=65536 '\
'--xla_tpu_enable_async_all_to_all=true '\
'--xla_tpu_enable_all_experimental_scheduler_features=true '\
'--xla_tpu_enable_scheduler_memory_pressure_tracking=true '\
'--xla_tpu_host_transfer_overlap_limit=24 '\
'--xla_tpu_aggressive_opt_barrier_removal=ENABLED '\
'--xla_lhs_prioritize_async_depth_over_stall=ENABLED '\
'--xla_should_allow_loop_variant_parameter_in_chain=ENABLED '\
'--xla_should_add_loop_invariant_op_in_chain=ENABLED '\
'--xla_max_concurrent_host_send_recv=100 '\
'--xla_tpu_scheduler_percent_shared_memory_limit=100 '\
'--xla_latency_hiding_scheduler_rerun=2 '\
'--xla_tpu_use_minor_sharding_for_major_trivial_input=true '\
'--xla_tpu_relayout_group_size_threshold_for_reduce_scatter=1 '\
'--xla_tpu_spmd_rng_bit_generator_unsafe=true '\
'--xla_tpu_assign_all_reduce_scatter_layout=true'

source "$VENV/bin/activate"
export PYTHONPATH=$REPO_DIR/src:${PYTHONPATH:-}
export HF_HOME=/dev/shm/maxdiffusion_cache/huggingface
export HF_HUB_CACHE=/dev/shm/maxdiffusion_cache/huggingface/hub
export HF_HUB_ENABLE_HF_TRANSFER=1
export HF_HUB_OFFLINE=1
export JAX_COMPILATION_CACHE_DIR=/dev/shm/maxdiffusion_cache/jax
export XLA_CACHE_DIR=/dev/shm/maxdiffusion_cache/xla
export TMPDIR=/dev/shm/maxdiffusion_cache/tmp

# ── helper ────────────────────────────────────────────────────────────────────
# For hybrid `ulysses_ring`, the 4th and 5th args are the ring and ulysses sizes
# (which use dedicated mesh axes); context and tensor are pinned to 1.
# For pure `ulysses` and `ring`, the 4th arg is context size and 5th is unused.
run_case() {
  local run_name="$1"
  local attention="$2"
  local ici_dp="$3"
  local ici_a="$4"  # context (pure) | ring (hybrid)
  local ici_b="$5"  # unused (pure)  | ulysses (hybrid)

  local results_dir="$RESULTS_ROOT/$run_name"
  rm -rf "$results_dir" "$OUTPUT_ROOT/$run_name"
  mkdir -p "$results_dir"
  rm -f /tmp/libtpu_lockfile
  mkdir -p "$TMPDIR"

  local ici_cp ici_tp ici_ring ici_ulysses
  if [[ "$attention" == "ulysses_ring" || "$attention" == "ulysses_tokamax_ring" ]]; then
    ici_cp=1; ici_tp=1; ici_ring="$ici_a"; ici_ulysses="$ici_b"
    echo "[$(date -u +%T)] ── Starting $run_name (attention=$attention dp=$ici_dp ring=$ici_ring ulysses=$ici_ulysses) ──"
  else
    ici_cp="$ici_a"; ici_tp=1; ici_ring=1; ici_ulysses=1
    echo "[$(date -u +%T)] ── Starting $run_name (attention=$attention dp=$ici_dp cp=$ici_cp) ──"
  fi

  local common_args="run_name=$run_name \
    attention=$attention \
    ici_data_parallelism=$ici_dp \
    ici_fsdp_parallelism=1 \
    ici_context_parallelism=$ici_cp \
    ici_tensor_parallelism=$ici_tp \
    ici_ring_parallelism=$ici_ring \
    ici_ulysses_parallelism=$ici_ulysses \
    dcn_data_parallelism=1 \
    dcn_fsdp_parallelism=1 \
    dcn_context_parallelism=1 \
    dcn_tensor_parallelism=1 \
    dcn_ring_parallelism=1 \
    dcn_ulysses_parallelism=1 \
    height=720 width=1280 num_frames=81 num_inference_steps=40 \
    per_device_batch_size=0.125 \
    output_dir=$OUTPUT_ROOT \
    scan_layers=True \
    write_metrics=False \
    write_timing_metrics=False \
    enable_profiler=True \
    skip_first_n_steps_for_profiler=5 \
    profiler_steps=10"

  # launch worker 1 in background via SSH
  local remote_cmd
  remote_cmd="$(printf '%q' "set -euo pipefail
export LIBTPU_INIT_ARGS='$LIBTPU_INIT_ARGS'
source $VENV/bin/activate
export PYTHONPATH=$REPO_DIR/src:\${PYTHONPATH:-}
export HF_HOME=$HF_HOME
export HF_HUB_CACHE=$HF_HUB_CACHE
export HF_HUB_ENABLE_HF_TRANSFER=1
export HF_HUB_OFFLINE=1
export JAX_COMPILATION_CACHE_DIR=$JAX_COMPILATION_CACHE_DIR
export XLA_CACHE_DIR=$XLA_CACHE_DIR
export TMPDIR=$TMPDIR
rm -f /tmp/libtpu_lockfile
mkdir -p $TMPDIR $results_dir $OUTPUT_ROOT
cd $results_dir
python -u $REPO_DIR/src/maxdiffusion/generate_wan.py $CONFIG $common_args 2>&1 | tee $results_dir/worker1.log")"

  /usr/bin/ssh $SSH_OPTS "$WORKER1_USER@$WORKER1_IP" "bash -lc $remote_cmd" &
  local remote_pid=$!

  cd "$results_dir"
  # Disable -e for the python call so a single run failing doesn't kill the whole sequence.
  set +e
  python -u "$REPO_DIR/src/maxdiffusion/generate_wan.py" \
    "$CONFIG" \
    $common_args \
    2>&1 | tee "$results_dir/worker0.log"
  local local_status=$?

  wait "$remote_pid"
  local remote_status=$?
  set -e

  echo "[$(date -u +%T)] $run_name done — local=$local_status remote=$remote_status"
  echo "────────────────────────────────────────────────────────────────────────"
}

# ── run matrix ────────────────────────────────────────────────────────────────
# Pure runs:   args = dp cp _           (cp on context axis)
# Hybrid runs: args = dp ring ulysses   (dedicated ring + ulysses axes)
#        run_name                    attention     dp  a  b
run_case ulysses_dp2_cp8             ulysses       2   8  1
run_case ring_dp2_cp8                ring          2   8  1
run_case ulysses_ring_dp2_cp8_4x2    ulysses_ring  2   4  2
run_case ulysses_ring_dp2_cp8_2x4    ulysses_ring  2   2  4
run_case ulysses_ring_dp1_cp16_4x4   ulysses_ring  1   4  4

echo "[$(date -u +%T)] All benchmark runs complete."

docs/tpu_multihost_wan_bench.md

Lines changed: 215 additions & 0 deletions
@@ -0,0 +1,215 @@
# TPU Multihost WAN Benchmarks

This note shows how to connect to a TPU v7x-16 multihost VM and run WAN attention comparisons on both workers.

The examples below assume:

- TPU name: `rish-tpu-7x16`
- Zone: `europe-west2-a`
- Project: `tpu-prod-env-one-vm`
- Repo path on both workers: `/dev/shm/maxdiffusion`
- Venv path on both workers: `/dev/shm/maxdiffusion/.venv-tpu`
- HF cache path on both workers: `/dev/shm/maxdiffusion_cache/huggingface`

For the current `rish-tpu-7x16` allocation, the worker endpoints resolved to:

- worker 0: `sa_112155357684894056033@10.154.0.62`
- worker 1: `sa_112155357684894056033@10.154.0.59`
- worker 0 host key alias: `tpu.2884514015978940116-0-IE5JIi`
- worker 1 host key alias: `tpu.2884514015978940116-1-jgVFrB`

## Connect

Set the project and zone:

```bash
gcloud config set project tpu-prod-env-one-vm
gcloud config set compute/zone europe-west2-a
```

Check that both workers are reachable:

```bash
gcloud alpha compute tpus tpu-vm ssh rish-tpu-7x16 \
  --worker=all \
  --internal-ip \
  --zone=europe-west2-a \
  --quiet \
  --command='hostname'
```

Print the raw ssh command for a single worker when you need it:

```bash
gcloud alpha compute tpus tpu-vm ssh rish-tpu-7x16 \
  --worker=0 \
  --internal-ip \
  --zone=europe-west2-a \
  --dry-run
```

## Stage Code On Worker 1

If worker 0 already has the repo, venv, and WAN checkpoint cache, mirror them to worker 1 from worker 0:

```bash
rsync -a --delete \
  --exclude='.git' \
  --exclude='.venv' \
  --exclude='.venv-tpu' \
  --exclude='__pycache__' \
  --exclude='*.pyc' \
  -e 'ssh -T -i ~/.ssh/google_compute_engine -o CheckHostIP=no -o HashKnownHosts=no -o HostKeyAlias=tpu.2884514015978940116-1-jgVFrB -o IdentitiesOnly=yes -o StrictHostKeyChecking=no -o UserKnownHostsFile=~/.ssh/google_compute_known_hosts' \
  /dev/shm/maxdiffusion/ \
  sa_112155357684894056033@10.154.0.59:/dev/shm/maxdiffusion/
```

```bash
rsync -a --delete \
  -e 'ssh -T -i ~/.ssh/google_compute_engine -o CheckHostIP=no -o HashKnownHosts=no -o HostKeyAlias=tpu.2884514015978940116-1-jgVFrB -o IdentitiesOnly=yes -o StrictHostKeyChecking=no -o UserKnownHostsFile=~/.ssh/google_compute_known_hosts' \
  /dev/shm/maxdiffusion/.venv-tpu/ \
  sa_112155357684894056033@10.154.0.59:/dev/shm/maxdiffusion/.venv-tpu/
```

```bash
rsync -a --partial --delete \
  -e 'ssh -T -i ~/.ssh/google_compute_engine -o CheckHostIP=no -o HashKnownHosts=no -o HostKeyAlias=tpu.2884514015978940116-1-jgVFrB -o IdentitiesOnly=yes -o StrictHostKeyChecking=no -o UserKnownHostsFile=~/.ssh/google_compute_known_hosts' \
  /dev/shm/maxdiffusion_cache/huggingface/hub/models--Wan-AI--Wan2.2-T2V-A14B-Diffusers/ \
  sa_112155357684894056033@10.154.0.59:/dev/shm/maxdiffusion_cache/huggingface/hub/models--Wan-AI--Wan2.2-T2V-A14B-Diffusers/
```

## Smoke Test

Run a multihost JAX initialization smoke test on both workers:

```bash
gcloud alpha compute tpus tpu-vm ssh rish-tpu-7x16 \
  --worker=all \
  --internal-ip \
  --zone=europe-west2-a \
  --quiet \
  --command='
set -e
source /dev/shm/maxdiffusion/.venv-tpu/bin/activate
export PYTHONPATH=/dev/shm/maxdiffusion/src:${PYTHONPATH}
python - <<'"'"'PY'"'"'
import socket
import jax

jax.distributed.initialize()
print(
    f"host={socket.gethostname()} "
    f"process_index={jax.process_index()} "
    f"process_count={jax.process_count()} "
    f"local_device_count={jax.local_device_count()} "
    f"device_count={jax.device_count()}"
)
PY'
```

## Run WAN Comparison Jobs

All commands below use:

- model config: `src/maxdiffusion/configs/base_wan_27b.yml`
- checkpoint: `Wan-AI/Wan2.2-T2V-A14B-Diffusers`
- global batch size: `2`
- per-device batch size: `0.125`
- total devices: `16`

Common environment:

```bash
export REPO_DIR=/dev/shm/maxdiffusion
export VENV=/dev/shm/maxdiffusion/.venv-tpu
export HF_HOME=/dev/shm/maxdiffusion_cache/huggingface
export HF_HUB_CACHE=/dev/shm/maxdiffusion_cache/huggingface/hub
export HF_HUB_ENABLE_HF_TRANSFER=1
export JAX_COMPILATION_CACHE_DIR=/dev/shm/maxdiffusion_cache/jax
export XLA_CACHE_DIR=/dev/shm/maxdiffusion_cache/xla
export TMPDIR=/dev/shm/maxdiffusion_cache/tmp
export RESULTS_ROOT=/dev/shm/maxdiffusion/bench_results
export CONFIG=$REPO_DIR/src/maxdiffusion/configs/base_wan_27b.yml
export COMMON_ARGS='height=720 width=1280 num_frames=81 num_inference_steps=40 per_device_batch_size=0.125 output_dir=/dev/shm/maxdiffusion/bench_outputs scan_layers=True enable_profiler=False'
```

Helper to run one job on both workers:

```bash
run_case() {
  local run_name="$1"
  local attention="$2"
  local ici_dp="$3"
  local ici_cp="$4"
  local ici_tp="$5"

  gcloud alpha compute tpus tpu-vm ssh rish-tpu-7x16 \
    --worker=all \
    --internal-ip \
    --zone=europe-west2-a \
    --quiet \
    --command="
set -e
source $VENV/bin/activate
export PYTHONPATH=$REPO_DIR/src:\${PYTHONPATH}
export HF_HOME=$HF_HOME
export HF_HUB_CACHE=$HF_HUB_CACHE
export HF_HUB_ENABLE_HF_TRANSFER=$HF_HUB_ENABLE_HF_TRANSFER
export JAX_COMPILATION_CACHE_DIR=$JAX_COMPILATION_CACHE_DIR
export XLA_CACHE_DIR=$XLA_CACHE_DIR
export TMPDIR=$TMPDIR
mkdir -p $RESULTS_ROOT /dev/shm/maxdiffusion/bench_outputs
cd $RESULTS_ROOT
python $REPO_DIR/src/maxdiffusion/generate_wan.py \
  $CONFIG \
  run_name=$run_name \
  attention=$attention \
  ici_data_parallelism=$ici_dp \
  ici_fsdp_parallelism=1 \
  ici_context_parallelism=$ici_cp \
  ici_tensor_parallelism=$ici_tp \
  dcn_data_parallelism=1 \
  dcn_fsdp_parallelism=1 \
  dcn_context_parallelism=1 \
  dcn_tensor_parallelism=1 \
  $COMMON_ARGS \
  2>&1 | tee $RESULTS_ROOT/$run_name.\$(hostname).log
"
}
```

Run matrix:

```bash
run_case ulysses_dp2_cp8           ulysses      2 8 1
run_case ring_dp2_cp8              ring         2 8 1
run_case ulysses_ring_dp2_cp8_2x4  ulysses_ring 2 2 4
run_case ulysses_ring_dp2_cp8_4x2  ulysses_ring 2 4 2
run_case ulysses_ring_dp1_cp16_4x4 ulysses_ring 1 4 4
```

## Topology Note

TPU v7x exposes dual chiplets as two JAX devices. For `ulysses_ring`, keep the dual-chip pairing inside the Ulysses group by setting `ici_tensor_parallelism >= 2`.

That means:

- `2x4` uses tensor `4`, so the dual-chip pairing is inside the Ulysses side.
- `4x2` uses tensor `2`, so the dual-chip pairing is still inside the Ulysses side.
- `4x4` uses tensor `4`, so the dual-chip pairing is still inside the Ulysses side.

The plain `ring` baseline has no Ulysses group, so it cannot preserve that property by construction.
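One way to confirm where the chiplet pairs land on a live allocation is to print per-device placement metadata. A sketch; `coords` and `core_on_chip` are TPU device attributes whose availability varies by generation and JAX version:

```python
# Sanity-check sketch: list local devices with their physical coordinates so you
# can verify that the two chiplets of a chip sit inside one Ulysses group.
import jax

for d in jax.local_devices():
    print(d.id, getattr(d, "coords", None), getattr(d, "core_on_chip", None))
```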
## Read Results

Pull the timing summary from the per-host logs:

```bash
rg -n "compile_time:|generation_time:|generation time per video:|TIMING SUMMARY" /dev/shm/maxdiffusion/bench_results/*.log
```

If you want a single run's logs from both workers:

```bash
ls -1 /dev/shm/maxdiffusion/bench_results/ulysses_ring_dp2_cp8_2x4.*
```
