|
| 1 | +#!/usr/bin/env bash |
| 2 | +set -euo pipefail |
| 3 | + |
| 4 | +# ── config ──────────────────────────────────────────────────────────────────── |
| 5 | +REPO_DIR=/dev/shm/maxdiffusion |
| 6 | +VENV=$REPO_DIR/.venv-tpu |
| 7 | +CONFIG=$REPO_DIR/src/maxdiffusion/configs/base_wan_27b.yml |
| 8 | +RESULTS_ROOT=$REPO_DIR/bench_results |
| 9 | +OUTPUT_ROOT=$REPO_DIR/bench_outputs |
| 10 | + |
| 11 | +WORKER1_USER=sa_112155357684894056033 |
| 12 | +WORKER1_IP=10.154.0.59 |
| 13 | +WORKER1_HOST_ALIAS=tpu.2884514015978940116-1-jgVFrB |
| 14 | +SSH_KEY=/home/sagarchapara_google_com/.ssh/google_compute_engine |
| 15 | +SSH_KNOWN_HOSTS=/home/sagarchapara_google_com/.ssh/google_compute_known_hosts |
| 16 | +SSH_OPTS="-T -i $SSH_KEY -o CheckHostIP=no -o HashKnownHosts=no \ |
| 17 | + -o HostKeyAlias=$WORKER1_HOST_ALIAS -o IdentitiesOnly=yes \ |
| 18 | + -o StrictHostKeyChecking=no -o UserKnownHostsFile=$SSH_KNOWN_HOSTS" |
| 19 | + |
| 20 | +export LIBTPU_INIT_ARGS=\ |
| 21 | +'--xla_tpu_dvfs_p_state=7 '\ |
| 22 | +'--xla_tpu_enable_async_collective_fusion_fuse_all_gather=true '\ |
| 23 | +'--xla_tpu_megacore_fusion_allow_ags=false '\ |
| 24 | +'--xla_enable_async_collective_permute=true '\ |
| 25 | +'--xla_tpu_enable_ag_backward_pipelining=true '\ |
| 26 | +'--xla_tpu_enable_data_parallel_all_reduce_opt=true '\ |
| 27 | +'--xla_tpu_data_parallel_opt_different_sized_ops=true '\ |
| 28 | +'--xla_tpu_enable_async_collective_fusion=true '\ |
| 29 | +'--xla_tpu_enable_async_collective_fusion_multiple_steps=true '\ |
| 30 | +'--xla_tpu_overlap_compute_collective_tc=true '\ |
| 31 | +'--xla_enable_async_all_gather=true '\ |
| 32 | +'--xla_tpu_scoped_vmem_limit_kib=65536 '\ |
| 33 | +'--xla_tpu_enable_async_all_to_all=true '\ |
| 34 | +'--xla_tpu_enable_latency_hiding_scheduler=true '\ |
| 35 | +'--xla_tpu_enable_all_experimental_scheduler_features=true '\ |
| 36 | +'--xla_tpu_enable_scheduler_memory_pressure_tracking=true '\ |
| 37 | +'--xla_tpu_host_transfer_overlap_limit=24 '\ |
| 38 | +'--xla_tpu_aggressive_opt_barrier_removal=ENABLED '\ |
| 39 | +'--xla_lhs_prioritize_async_depth_over_stall=ENABLED '\ |
| 40 | +'--xla_should_allow_loop_variant_parameter_in_chain=ENABLED '\ |
| 41 | +'--xla_should_add_loop_invariant_op_in_chain=ENABLED '\ |
| 42 | +'--xla_max_concurrent_host_send_recv=100 '\ |
| 43 | +'--xla_tpu_scheduler_percent_shared_memory_limit=100 '\ |
| 44 | +'--xla_latency_hiding_scheduler_rerun=5 '\ |
| 45 | +'--xla_tpu_use_minor_sharding_for_major_trivial_input=true '\ |
| 46 | +'--xla_tpu_relayout_group_size_threshold_for_reduce_scatter=1 '\ |
| 47 | +'--xla_tpu_spmd_rng_bit_generator_unsafe=true '\ |
| 48 | +'--xla_tpu_assign_all_reduce_scatter_layout=true '\ |
| 49 | +'--xla_max_concurrent_async_collective_permutes=16 '\ |
| 50 | +'--xla_tpu_enable_ici_ag_pipelining=true' |
| 51 | + |
| 52 | +source "$VENV/bin/activate" |
| 53 | +export PYTHONPATH=$REPO_DIR/src:${PYTHONPATH:-} |
| 54 | +export HF_HOME=/dev/shm/maxdiffusion_cache/huggingface |
| 55 | +export HF_HUB_CACHE=/dev/shm/maxdiffusion_cache/huggingface/hub |
| 56 | +export HF_HUB_ENABLE_HF_TRANSFER=1 |
| 57 | +export HF_HUB_OFFLINE=1 |
| 58 | +export JAX_COMPILATION_CACHE_DIR=/dev/shm/maxdiffusion_cache/jax |
| 59 | +export XLA_CACHE_DIR=/dev/shm/maxdiffusion_cache/xla |
| 60 | +export TMPDIR=/dev/shm/maxdiffusion_cache/tmp |
| 61 | + |
| 62 | +# ── helper ──────────────────────────────────────────────────────────────────── |
| 63 | +# run_case <run_name> <attention> <ici_dp> <ici_a> <ici_b> <per_device_batch_size> |
| 64 | +# pure (flash/tokamax_ring/ulysses): ici_a=cp, ici_b unused |
| 65 | +# hybrid (ulysses_ring): ici_a=ring, ici_b=ulysses |
| 66 | +run_case() { |
| 67 | + local run_name="$1" |
| 68 | + local attention="$2" |
| 69 | + local ici_dp="$3" |
| 70 | + local ici_a="$4" |
| 71 | + local ici_b="$5" |
| 72 | + local pdb="$6" # per_device_batch_size |
| 73 | + |
| 74 | + local results_dir="$RESULTS_ROOT/$run_name" |
| 75 | + rm -rf "$results_dir" "$OUTPUT_ROOT/$run_name" |
| 76 | + mkdir -p "$results_dir" |
| 77 | + rm -f /tmp/libtpu_lockfile |
| 78 | + mkdir -p "$TMPDIR" |
| 79 | + |
| 80 | + local ici_cp ici_tp ici_ring ici_ulysses |
| 81 | + if [[ "$attention" == "ulysses_ring" || "$attention" == "ulysses_tokamax_ring" ]]; then |
| 82 | + ici_cp=1; ici_tp=1; ici_ring="$ici_a"; ici_ulysses="$ici_b" |
| 83 | + echo "[$(date -u +%T)] ── Starting $run_name (attention=$attention dp=$ici_dp ring=$ici_ring ulysses=$ici_ulysses pdb=$pdb) ──" |
| 84 | + else |
| 85 | + ici_cp="$ici_a"; ici_tp=1; ici_ring=1; ici_ulysses=1 |
| 86 | + echo "[$(date -u +%T)] ── Starting $run_name (attention=$attention dp=$ici_dp cp=$ici_cp pdb=$pdb) ──" |
| 87 | + fi |
| 88 | + |
| 89 | + # Only capture profiler for bs=2 (per_device_batch_size=0.125) |
| 90 | + local profiler_args="enable_profiler=False" |
| 91 | + if [[ "$pdb" == "0.125" ]]; then |
| 92 | + profiler_args="enable_profiler=True skip_first_n_steps_for_profiler=5 profiler_steps=10" |
| 93 | + fi |
| 94 | + |
| 95 | + local common_args="run_name=$run_name \ |
| 96 | + attention=$attention \ |
| 97 | + ici_data_parallelism=$ici_dp \ |
| 98 | + ici_fsdp_parallelism=1 \ |
| 99 | + ici_context_parallelism=$ici_cp \ |
| 100 | + ici_tensor_parallelism=$ici_tp \ |
| 101 | + ici_ring_parallelism=$ici_ring \ |
| 102 | + ici_ulysses_parallelism=$ici_ulysses \ |
| 103 | + dcn_data_parallelism=1 \ |
| 104 | + dcn_fsdp_parallelism=1 \ |
| 105 | + dcn_context_parallelism=1 \ |
| 106 | + dcn_tensor_parallelism=1 \ |
| 107 | + dcn_ring_parallelism=1 \ |
| 108 | + dcn_ulysses_parallelism=1 \ |
| 109 | + height=720 width=1280 num_frames=81 num_inference_steps=40 \ |
| 110 | + per_device_batch_size=$pdb \ |
| 111 | + output_dir=$OUTPUT_ROOT \ |
| 112 | + scan_layers=True \ |
| 113 | + write_metrics=False \ |
| 114 | + write_timing_metrics=False \ |
| 115 | + $profiler_args" |
| 116 | + |
| 117 | + local remote_cmd |
| 118 | + remote_cmd="$(printf '%q' "set -euo pipefail |
| 119 | +export LIBTPU_INIT_ARGS='$LIBTPU_INIT_ARGS' |
| 120 | +source $VENV/bin/activate |
| 121 | +export PYTHONPATH=$REPO_DIR/src:\${PYTHONPATH:-} |
| 122 | +export HF_HOME=$HF_HOME |
| 123 | +export HF_HUB_CACHE=$HF_HUB_CACHE |
| 124 | +export HF_HUB_ENABLE_HF_TRANSFER=1 |
| 125 | +export HF_HUB_OFFLINE=1 |
| 126 | +export JAX_COMPILATION_CACHE_DIR=$JAX_COMPILATION_CACHE_DIR |
| 127 | +export XLA_CACHE_DIR=$XLA_CACHE_DIR |
| 128 | +export TMPDIR=$TMPDIR |
| 129 | +rm -f /tmp/libtpu_lockfile |
| 130 | +mkdir -p $TMPDIR $results_dir $OUTPUT_ROOT |
| 131 | +cd $results_dir |
| 132 | +python -u $REPO_DIR/src/maxdiffusion/generate_wan.py $CONFIG $common_args 2>&1 | tee $results_dir/worker1.log")" |
| 133 | + |
| 134 | + /usr/bin/ssh $SSH_OPTS "$WORKER1_USER@$WORKER1_IP" "bash -lc $remote_cmd" & |
| 135 | + local remote_pid=$! |
| 136 | + |
| 137 | + cd "$results_dir" |
| 138 | + set +e |
| 139 | + python -u "$REPO_DIR/src/maxdiffusion/generate_wan.py" \ |
| 140 | + "$CONFIG" \ |
| 141 | + $common_args \ |
| 142 | + 2>&1 | tee "$results_dir/worker0.log" |
| 143 | + local local_status=$? |
| 144 | + |
| 145 | + wait "$remote_pid" |
| 146 | + local remote_status=$? |
| 147 | + set -e |
| 148 | + |
| 149 | + echo "[$(date -u +%T)] $run_name done — local=$local_status remote=$remote_status" |
| 150 | + echo "────────────────────────────────────────────────────────────────────────" |
| 151 | +} |
| 152 | + |
| 153 | +# ── run matrix ──────────────────────────────────────────────────────────────── |
| 154 | +# Columns: run_name attention dp a b per_device_bs |
| 155 | +# |
| 156 | +# All runs: 16 total devices (2 hosts × 8), dp=2, cp=8 (or ring×ulysses=8) |
| 157 | +# Batch sizes: 0.0625=bs1, 0.125=bs2, 0.25=bs4 |
| 158 | + |
| 159 | +for pdb in 0.0625 0.125 0.25 0.5; do |
| 160 | + bs=$(python3 -c "print(int($pdb * 16))") |
| 161 | + run_case "flash_dp2_cp8_bs${bs}" flash 2 8 1 $pdb |
| 162 | + run_case "tokamax_ring_dp2_cp8_bs${bs}" tokamax_ring 2 8 1 $pdb |
| 163 | + run_case "ulysses_dp2_cp8_bs${bs}" ulysses 2 8 1 $pdb |
| 164 | + run_case "ulysses_ring_dp2_r2u4_bs${bs}" ulysses_ring 2 2 4 $pdb # ring=2, ulysses=4, dp=2 |
| 165 | + run_case "ulysses_ring_dp2_r4u2_bs${bs}" ulysses_ring 2 4 2 $pdb # ring=4, ulysses=2, dp=2 |
| 166 | + run_case "ulysses_ring_dp1_r4u4_bs${bs}" ulysses_ring 1 4 4 $pdb # ring=4, ulysses=4, dp=1 |
| 167 | +done |
| 168 | + |
| 169 | +echo "[$(date -u +%T)] All benchmark runs complete." |
0 commit comments