|
#!/usr/bin/env bash
# Strict mode: stop on command failure, on use of an unset variable, and on a
# failure in any stage of a pipeline.
set -o errexit -o nounset -o pipefail

# ── config (TPU v7-8, single host, 8 chips) ──────────────────────────────────
# Checkout that holds the generation script, its config, and the bench dirs.
REPO_DIR="/mnt/data/sagarchapara/workspace/maxdiffusion"
# Python virtual environment that is activated before any run.
VENV="/mnt/data/sagarchapara/workspace/venv"
# Config file handed to the generation script as its first argument.
CONFIG="${REPO_DIR}/src/maxdiffusion/configs/base_wan_27b.yml"
# Per-run logs go under RESULTS_ROOT/<run_name>; OUTPUT_ROOT is passed to the
# generation script as output_dir.
RESULTS_ROOT="${REPO_DIR}/bench_results"
OUTPUT_ROOT="${REPO_DIR}/bench_outputs"

# Directory passed to every run as pretrained_orbax_dir; created up front.
PRETRAINED_ORBAX_DIR="/mnt/data/sagarchapara/workspace/wan22_orbax_cache"
mkdir -p "${PRETRAINED_ORBAX_DIR}"
| 13 | + |
# Flags exported through LIBTPU_INIT_ARGS. They are listed one per array
# element and then joined with single spaces (no leading or trailing space).
libtpu_flags=(
  --xla_tpu_dvfs_p_state=7
  --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true
  --xla_tpu_megacore_fusion_allow_ags=false
  --xla_enable_async_collective_permute=true
  --xla_tpu_enable_ag_backward_pipelining=true
  --xla_tpu_enable_data_parallel_all_reduce_opt=true
  --xla_tpu_data_parallel_opt_different_sized_ops=true
  --xla_tpu_enable_async_collective_fusion=true
  --xla_tpu_enable_async_collective_fusion_multiple_steps=true
  --xla_tpu_overlap_compute_collective_tc=true
  --xla_enable_async_all_gather=true
  --xla_tpu_scoped_vmem_limit_kib=65536
  --xla_tpu_enable_async_all_to_all=true
  --xla_tpu_enable_latency_hiding_scheduler=true
  --xla_tpu_enable_all_experimental_scheduler_features=true
  --xla_tpu_enable_scheduler_memory_pressure_tracking=true
  --xla_tpu_host_transfer_overlap_limit=24
  --xla_tpu_aggressive_opt_barrier_removal=ENABLED
  --xla_lhs_prioritize_async_depth_over_stall=ENABLED
  --xla_should_allow_loop_variant_parameter_in_chain=ENABLED
  --xla_should_add_loop_invariant_op_in_chain=ENABLED
  --xla_max_concurrent_host_send_recv=100
  --xla_tpu_scheduler_percent_shared_memory_limit=100
  --xla_latency_hiding_scheduler_rerun=5
  --xla_tpu_use_minor_sharding_for_major_trivial_input=true
  --xla_tpu_relayout_group_size_threshold_for_reduce_scatter=1
  --xla_tpu_spmd_rng_bit_generator_unsafe=true
  --xla_tpu_assign_all_reduce_scatter_layout=true
  --xla_max_concurrent_async_collective_permutes=16
  --xla_tpu_enable_ici_ag_pipelining=true
)
# Join inside a subshell with IFS pinned to a space so the separator does not
# depend on whatever IFS the surrounding script is using.
LIBTPU_INIT_ARGS=$(IFS=' '; printf '%s' "${libtpu_flags[*]}")
export LIBTPU_INIT_ARGS
unset libtpu_flags
| 45 | + |
# Activate the virtual environment so `python` resolves inside it.
source "$VENV/bin/activate"

# Put the repo sources first on PYTHONPATH. The ':' separator is added only
# when PYTHONPATH already has a value, so an unset/empty PYTHONPATH does not
# leave a trailing ':' (i.e. an empty path entry) behind.
export PYTHONPATH="${REPO_DIR}/src${PYTHONPATH:+:${PYTHONPATH}}"

# All caches and temp files live under one /dev/shm root.
CACHE_ROOT=/dev/shm/maxdiffusion_cache
export HF_HOME="$CACHE_ROOT/huggingface"
export HF_HUB_CACHE="$HF_HOME/hub"
export HF_HUB_ENABLE_HF_TRANSFER=1
export JAX_COMPILATION_CACHE_DIR="$CACHE_ROOT/jax"
export XLA_CACHE_DIR="$CACHE_ROOT/xla"
export TMPDIR="$CACHE_ROOT/tmp"
mkdir -p "$TMPDIR" "$HF_HOME" "$HF_HUB_CACHE" "$JAX_COMPILATION_CACHE_DIR" "$XLA_CACHE_DIR"
| 55 | + |
# ── helper (single host - no SSH) ────────────────────────────────────────────
#######################################
# Run one generate_wan.py benchmark case locally and tee its output to a log.
# Globals (read):
#   RESULTS_ROOT, OUTPUT_ROOT, TMPDIR, REPO_DIR, CONFIG, PRETRAINED_ORBAX_DIR
# Arguments:
#   $1 run_name  - names the results/output sub-directories; must be non-empty
#   $2 attention - passed through as attention=
#   $3 ici_dp    - passed through as ici_data_parallelism=
#   $4 ici_cp    - passed through as ici_context_parallelism=
#   $5 (unused)  - accepted only to keep the positional layout stable
#   $6 pdb       - passed through as per_device_batch_size=
# Outputs:
#   Progress lines on stdout. The python process's combined stdout/stderr is
#   echoed to stdout and saved in $RESULTS_ROOT/<run_name>/worker0.log.
# Returns:
#   2 on bad arguments; otherwise 0. The python exit status is only reported
#   in the "done" line, so one failing case does not stop the remaining ones.
#######################################
run_case() {
  if (($# < 6)); then
    printf 'usage: run_case <run_name> <attention> <ici_dp> <ici_cp> <unused> <per_device_batch_size>\n' >&2
    return 2
  fi

  local run_name="$1"
  local attention="$2"
  local ici_dp="$3"
  local ici_cp="$4"
  local pdb="$6"
  local ici_tp=1

  # An empty run name would make the rm -rf below target the results/output
  # roots themselves, wiping every previous run.
  if [[ -z "$run_name" ]]; then
    printf 'run_case: run_name must not be empty\n' >&2
    return 2
  fi

  # Start from a clean slate for this run name.
  local results_dir="$RESULTS_ROOT/$run_name"
  rm -rf -- "${results_dir:?}" "${OUTPUT_ROOT:?}/${run_name}"
  mkdir -p "$results_dir"
  rm -f /tmp/libtpu_lockfile
  mkdir -p "$TMPDIR"

  echo "[$(date -u +%T)] ── Starting $run_name (attention=$attention dp=$ici_dp cp=$ici_cp pdb=$pdb) ──"

  # Profiler only for bs=2 (pdb=0.25 with 8 devices)
  local -a profiler_args=(enable_profiler=False)
  if [[ "$pdb" == "0.25" ]]; then
    profiler_args=(
      enable_profiler=True
      skip_first_n_steps_for_profiler=5
      profiler_steps=10
    )
  fi

  # One array element per key=value override, so values that contain
  # whitespace (e.g. paths) reach python as a single argument.
  local -a cli_args=(
    "run_name=$run_name"
    "attention=$attention"
    "ici_data_parallelism=$ici_dp"
    ici_fsdp_parallelism=1
    "ici_context_parallelism=$ici_cp"
    "ici_tensor_parallelism=$ici_tp"
    dcn_data_parallelism=1
    dcn_fsdp_parallelism=1
    dcn_context_parallelism=1
    dcn_tensor_parallelism=1
    "pretrained_orbax_dir=$PRETRAINED_ORBAX_DIR"
    height=720 width=1280 num_frames=81 num_inference_steps=40
    "per_device_batch_size=$pdb"
    "output_dir=$OUTPUT_ROOT"
    scan_layers=True
    write_metrics=False
    write_timing_metrics=False
    "${profiler_args[@]}"
  )

  # The python process runs with the results directory as its working
  # directory, but inside a subshell so the caller's cwd is left alone.
  # Wrapping the pipeline in `{ …; } || true` keeps errexit from aborting on a
  # failed run without toggling the caller's `set -e` state, and PIPESTATUS is
  # captured immediately so the reported status does not depend on pipefail.
  local -a stage_status=()
  {
    (
      cd -- "$results_dir" &&
        python -u "$REPO_DIR/src/maxdiffusion/generate_wan.py" \
          "$CONFIG" \
          "${cli_args[@]}"
    ) 2>&1 | tee "$results_dir/worker0.log"
    stage_status=("${PIPESTATUS[@]}")
  } || true

  # Report python's status; fall back to tee's if python itself succeeded.
  local status="${stage_status[0]}"
  if ((status == 0)); then
    status="${stage_status[1]}"
  fi

  echo "[$(date -u +%T)] $run_name done — status=$status"
  echo "────────────────────────────────────────────────────────────────────────"
}
| 113 | + |
# ── run matrix (TPU v7-8: 1 host × 8 chips = 8 devices) ──────────────────────
# Parallelism rule: dp × fsdp × cp × tp = 8
#
# Every attention kernel is run with dp=2, cp=4 (2×4=8).
#
# per_device_batch_size × 8 devices = total batch size:
#   0.125 → bs1, 0.25 → bs2, 0.5 → bs4

for pdb in 0.125 0.25 0.5; do
  # Total batch size, used only to label the run.
  bs=$(python3 -c "print(int($pdb * 8))")

  # Same kernel order for every batch size; the fifth argument is the
  # placeholder slot that run_case documents as unused.
  for kernel in flash tokamax_ring ulysses ulysses_ring; do
    run_case "${kernel}_dp2_cp4_bs${bs}" "$kernel" 2 4 1 "$pdb"
  done
done

echo "[$(date -u +%T)] All benchmark runs complete."
0 commit comments