Skip to content

Commit 2582cf1

Browse files
committed
Add TPU v7-8 attention benchmark script
1 parent 0fe15fd commit 2582cf1

1 file changed

Lines changed: 131 additions & 0 deletions

File tree

bench_attn_v78.sh

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
#!/usr/bin/env bash
set -euo pipefail

# ── config (TPU v7-8, single host, 8 chips) ──────────────────────────────────
# The checkout, the virtualenv and the Orbax cache all live under one
# workspace directory; the benchmark result/output trees live in the checkout.
workspace_dir=/mnt/data/sagarchapara/workspace
REPO_DIR="$workspace_dir/maxdiffusion"
VENV="$workspace_dir/venv"
CONFIG="$REPO_DIR/src/maxdiffusion/configs/base_wan_27b.yml"
RESULTS_ROOT="$REPO_DIR/bench_results"
OUTPUT_ROOT="$REPO_DIR/bench_outputs"

PRETRAINED_ORBAX_DIR="$workspace_dir/wan22_orbax_cache"
mkdir -p "$PRETRAINED_ORBAX_DIR"

# libtpu flags, one per array element. They are exported as a single
# space-separated string with no trailing space.
libtpu_flags=(
  --xla_tpu_dvfs_p_state=7
  --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true
  --xla_tpu_megacore_fusion_allow_ags=false
  --xla_enable_async_collective_permute=true
  --xla_tpu_enable_ag_backward_pipelining=true
  --xla_tpu_enable_data_parallel_all_reduce_opt=true
  --xla_tpu_data_parallel_opt_different_sized_ops=true
  --xla_tpu_enable_async_collective_fusion=true
  --xla_tpu_enable_async_collective_fusion_multiple_steps=true
  --xla_tpu_overlap_compute_collective_tc=true
  --xla_enable_async_all_gather=true
  --xla_tpu_scoped_vmem_limit_kib=65536
  --xla_tpu_enable_async_all_to_all=true
  --xla_tpu_enable_latency_hiding_scheduler=true
  --xla_tpu_enable_all_experimental_scheduler_features=true
  --xla_tpu_enable_scheduler_memory_pressure_tracking=true
  --xla_tpu_host_transfer_overlap_limit=24
  --xla_tpu_aggressive_opt_barrier_removal=ENABLED
  --xla_lhs_prioritize_async_depth_over_stall=ENABLED
  --xla_should_allow_loop_variant_parameter_in_chain=ENABLED
  --xla_should_add_loop_invariant_op_in_chain=ENABLED
  --xla_max_concurrent_host_send_recv=100
  --xla_tpu_scheduler_percent_shared_memory_limit=100
  --xla_latency_hiding_scheduler_rerun=5
  --xla_tpu_use_minor_sharding_for_major_trivial_input=true
  --xla_tpu_relayout_group_size_threshold_for_reduce_scatter=1
  --xla_tpu_spmd_rng_bit_generator_unsafe=true
  --xla_tpu_assign_all_reduce_scatter_layout=true
  --xla_max_concurrent_async_collective_permutes=16
  --xla_tpu_enable_ici_ag_pipelining=true
)
# "${array[*]}" joins on the first character of IFS, which is a space here.
export LIBTPU_INIT_ARGS="${libtpu_flags[*]}"

source "$VENV/bin/activate"
export PYTHONPATH="$REPO_DIR/src:${PYTHONPATH:-}"

# Hugging Face, JAX/XLA compilation caches and temp files all go under one
# directory on /dev/shm.
cache_root=/dev/shm/maxdiffusion_cache
export HF_HOME="$cache_root/huggingface"
export HF_HUB_CACHE="$HF_HOME/hub"
export HF_HUB_ENABLE_HF_TRANSFER=1
export JAX_COMPILATION_CACHE_DIR="$cache_root/jax"
export XLA_CACHE_DIR="$cache_root/xla"
export TMPDIR="$cache_root/tmp"
mkdir -p \
  "$TMPDIR" \
  "$HF_HOME" \
  "$HF_HUB_CACHE" \
  "$JAX_COMPILATION_CACHE_DIR" \
  "$XLA_CACHE_DIR"
55+
56+
# ── helper (single host - no SSH) ────────────────────────────────────────────
#######################################
# Run one generate_wan.py benchmark case and tee its combined stdout/stderr
# to <RESULTS_ROOT>/<run_name>/worker0.log.
#
# Usage:
#   run_case <run_name> <attention> <ici_dp> <ici_cp> <unused> <per_device_batch_size>
#
# Globals (read):
#   RESULTS_ROOT, OUTPUT_ROOT, REPO_DIR, CONFIG, PRETRAINED_ORBAX_DIR, TMPDIR
# Arguments:
#   $1 - run name; also the name of the per-run results/output directories
#   $2 - value passed as attention=
#   $3 - value passed as ici_data_parallelism=
#   $4 - value passed as ici_context_parallelism=
#   $5 - accepted for call compatibility, not used
#   $6 - value passed as per_device_batch_size=; the profiler is enabled only
#        when this is exactly the string "0.25"
# Outputs:
#   Start/finish banners on stdout; python output on stdout and in the log.
# Returns:
#   0 once the run has been attempted. The python exit status is only printed
#   in the "done" banner so that a failing case does not stop the callers'
#   remaining cases. 2 on bad arguments; non-zero if a directory cannot be
#   created.
#######################################
run_case() {
  if (( $# < 6 )); then
    printf 'usage: run_case <run_name> <attention> <ici_dp> <ici_cp> <unused> <per_device_batch_size>\n' >&2
    return 2
  fi

  local run_name="$1"
  local attention="$2"
  local ici_dp="$3"
  local ici_cp="$4"
  local pdb="$6"
  local ici_tp=1

  # The paths below are handed to "rm -rf"; an empty component would turn
  # them into the whole results/output root (or a path directly under /).
  if [[ -z "$run_name" || -z "${RESULTS_ROOT:-}" || -z "${OUTPUT_ROOT:-}" ]]; then
    printf 'run_case: run_name, RESULTS_ROOT and OUTPUT_ROOT must be non-empty\n' >&2
    return 2
  fi

  local results_dir="$RESULTS_ROOT/$run_name"
  rm -rf -- "$results_dir" "$OUTPUT_ROOT/$run_name"
  mkdir -p -- "$results_dir" || return
  rm -f /tmp/libtpu_lockfile
  mkdir -p -- "$TMPDIR" || return

  echo "[$(date -u +%T)] ── Starting $run_name (attention=$attention dp=$ici_dp cp=$ici_cp pdb=$pdb) ──"

  # Profiler only for bs=2 (pdb=0.25 with 8 devices)
  local -a profiler_args=(enable_profiler=False)
  if [[ "$pdb" == "0.25" ]]; then
    profiler_args=(
      enable_profiler=True
      skip_first_n_steps_for_profiler=5
      profiler_steps=10
    )
  fi

  # One array element per argument, so values containing whitespace or glob
  # characters reach python intact.
  local -a cmd=(
    python -u "$REPO_DIR/src/maxdiffusion/generate_wan.py"
    "$CONFIG"
    "run_name=$run_name"
    "attention=$attention"
    "ici_data_parallelism=$ici_dp"
    ici_fsdp_parallelism=1
    "ici_context_parallelism=$ici_cp"
    "ici_tensor_parallelism=$ici_tp"
    dcn_data_parallelism=1
    dcn_fsdp_parallelism=1
    dcn_context_parallelism=1
    dcn_tensor_parallelism=1
    "pretrained_orbax_dir=$PRETRAINED_ORBAX_DIR"
    height=720 width=1280 num_frames=81 num_inference_steps=40
    "per_device_batch_size=$pdb"
    "output_dir=$OUTPUT_ROOT"
    scan_layers=True
    write_metrics=False
    write_timing_metrics=False
    "${profiler_args[@]}"
  )

  # Run inside a subshell: the cd and the pipefail setting stay local, and
  # the caller's errexit state is never toggled. "|| status=$?" records a
  # failure without aborting a caller that runs under "set -e".
  local status=0
  (
    cd "$results_dir" || exit
    set -o pipefail
    "${cmd[@]}" 2>&1 | tee "$results_dir/worker0.log"
  ) || status=$?

  echo "[$(date -u +%T)] $run_name done — status=$status"
  echo "────────────────────────────────────────────────────────────────────────"
}
113+
114+
# ── run matrix (TPU v7-8: 1 host × 8 chips = 8 devices) ──────────────────────
# Parallelism rule: dp × fsdp × cp × tp = 8
#
# Every attention mode runs with dp=2, cp=4 (2×4=8).
#
# Total batch size is pdb × 8 devices:
#   0.125 → bs1,  0.25 → bs2,  0.5 → bs4

# Attention modes, in the order they are run for each batch size.
attention_modes=(flash tokamax_ring ulysses ulysses_ring)

for pdb in 0.125 0.25 0.5; do
  # Total batch size, used only to label the run.
  bs=$(python3 -c "print(int($pdb * 8))")

  for mode in "${attention_modes[@]}"; do
    run_case "${mode}_dp2_cp4_bs${bs}" "$mode" 2 4 1 "$pdb"
  done
done

echo "[$(date -u +%T)] All benchmark runs complete."

0 commit comments

Comments
 (0)