|
#!/bin/bash
# Slurm batch directives. sbatch stops reading #SBATCH lines at the first
# executable command, so they have to stay above `set`.
# --account, --partition and --job-name carry no value in this file
# (presumably filled in per cluster before submitting — confirm).
#SBATCH --account=
#SBATCH --nodes=1
#SBATCH --partition=
#SBATCH --ntasks-per-node=1
#SBATCH --time=03:55:00
#SBATCH --mem=0
#SBATCH --job-name=
#SBATCH --mail-type=FAIL
#SBATCH --overcommit
#SBATCH --exclusive

# Strict mode, spelled out: stop on the first failing command (errexit), treat
# unset variables as errors (nounset), trace every command (xtrace), and make a
# pipeline fail when any of its stages fails (pipefail).
set -o errexit -o nounset -o xtrace -o pipefail
| 13 | + |
| 14 | +# ============================================================================ |
| 15 | +# This script is adapted from the experiment scripts here: |
| 16 | +# https://gitlab-master.nvidia.com/bio-foundation-models/codon-fm/-/tree/405b2315836a9c1c1ae0c5e41d5abcf4f24d6aa8/experiment_scripts/pretraining/encodon_filtered/mlm |
| 17 | +# |
| 18 | +# Modifications: |
| 19 | +# - 'num_jobs' is not supported in the PTL recipe in bionemo-recipes. |
| 20 | +# - '--sharded-state-dict' is not supported in the PTL recipe in bionemo-recipes. It is always 'sharded'. |
| 21 | +# - Added support for selecting the sequence packing method (thd or bshd). |
| 22 | +# - Added support for selecting the distributed strategy (fsdp or ddp). |
| 23 | +# - Added support for selecting the gradient accumulation steps to keep the global batch size constant. |
| 24 | +# - Added support for selecting the attention backend (xformers or pytorch SDPA). |
| 25 | +# ============================================================================ |
| 26 | + |
# Chain bookkeeping. A job that already has CHAIN_ID in its environment (the
# EXIT trap resubmits with sbatch --export=...,CHAIN_ID=...) keeps it; a job
# without one starts a new chain named after its own SLURM_JOB_ID.
if [[ -n "${CHAIN_ID:-}" ]]; then
  echo "Continuing chain ${CHAIN_ID} (current job ${SLURM_JOB_ID})"
else
  export CHAIN_ID="${SLURM_JOB_ID}"
  echo "Starting NEW chain: CHAIN_ID=${CHAIN_ID}"
fi
| 34 | + |
| 35 | +# ============================================================================ |
| 36 | +# CodonFM |
| 37 | +# ============================================================================ |
| 38 | + |
# Site configuration. BASE_DIR, CONTAINER and WANDB_PROJECT are blank in this
# file and must be filled in; the guards below stop the job early instead of
# letting it continue with paths rooted at "/", an empty container image, or a
# --project_name flag that has no value.
BASE_DIR=""
CONTAINER=""
: "${BASE_DIR:?Set BASE_DIR in this script}"
: "${CONTAINER:?Set CONTAINER in this script}"
DATA_DIR="${BASE_DIR}/data"
CODE_MOUNT="/workspace/bionemo"

# Credentials must come from the caller's environment. xtrace is paused while
# they are checked: tracing `: "${VAR:?...}"` prints the expanded value, which
# would write the tokens into the job log.
_trace_was_on=0
if [[ $- == *x* ]]; then
  _trace_was_on=1
  { set +x; } 2>/dev/null
fi
: "${WANDB_API_KEY:?Set WANDB_API_KEY in ~/.bash_profile}"
: "${HUGGING_FACE_HUB_TOKEN:?Set HUGGING_FACE_HUB_TOKEN in ~/.bash_profile}"
if (( _trace_was_on )); then
  set -x
fi
unset _trace_was_on
: "${CLUSTER_NAME:?Set CLUSTER_NAME in ~/.bash_profile}"

export GLOBAL_BATCH_SIZE=1536
export MICRO_BATCH_SIZE=96

# Experiment parameters
export CONFIG_NAME=encodon_xx
export NPROC_PER_NODE=8
export DIST_STRATEGY=ddp # fsdp or ddp

# Training
export NUM_TRAIN_STEPS=100
export LEARNING_RATE=7.5e-5
export NUM_WORKERS=12
export USE_SEQUENCE_PACKING=False

export PRECISION=bf16-mixed

# Logging / W&B
export LOGGER_FREQUENCY=10
export WANDB_PROJECT=
: "${WANDB_PROJECT:?Set WANDB_PROJECT in this script}"

# Attn-backend
export USE_XFORMERS=1
export USE_TRANSFORMER_ENGINE=0
| 72 | + |
# Name fragments for the W&B run name, taken straight from the configuration.
# MODEL_SIZE is whatever follows the last underscore of CONFIG_NAME
# (encodon_xx -> xx).
MODEL_SIZE="${CONFIG_NAME##*_}"
PRECISION_TAG="${PRECISION}"

# Batch layout tag: "thd" when sequence packing is enabled, "bshd" otherwise.
case "${USE_SEQUENCE_PACKING}" in
  True) BATCH_TYPE_TAG="thd" ;;
  *) BATCH_TYPE_TAG="bshd" ;;
esac
| 82 | + |
# Derive gradient accumulation so that
#   GLOBAL_BATCH_SIZE = MICRO_BATCH_SIZE * total GPUs * GRAD_ACC_STEPS.
TOTAL_GPUS=$(( NPROC_PER_NODE * SLURM_JOB_NUM_NODES ))
TOTAL_PER_STEP=$(( MICRO_BATCH_SIZE * TOTAL_GPUS ))

# Reject non-positive sizes as well as non-multiples, so GRAD_ACC_STEPS is
# always >= 1 as the error text promises. The || chain short-circuits, so the
# modulo is never evaluated with a zero divisor.
if (( TOTAL_PER_STEP <= 0 || GLOBAL_BATCH_SIZE <= 0 || GLOBAL_BATCH_SIZE % TOTAL_PER_STEP != 0 )); then
  echo "ERROR: GLOBAL_BATCH_SIZE=${GLOBAL_BATCH_SIZE} must be a positive multiple of MICRO_BATCH_SIZE*NPROC_PER_NODE*NODES=${TOTAL_PER_STEP}" >&2
  exit 1
fi
export GRAD_ACC_STEPS=$(( GLOBAL_BATCH_SIZE / TOTAL_PER_STEP ))
echo "Batch sizing: GBS=${GLOBAL_BATCH_SIZE}, MBS=${MICRO_BATCH_SIZE}, NPROC=${NPROC_PER_NODE}, NODES=${SLURM_JOB_NUM_NODES}, GRAD_ACC=${GRAD_ACC_STEPS}"

# The run name encodes every knob above plus the chain ID; it is also used to
# name the per-run results and checkpoint directories.
export WANDB_RUN_NAME="${MODEL_SIZE}_${DIST_STRATEGY}_${BATCH_TYPE_TAG}_gbs${GLOBAL_BATCH_SIZE}_mbs${MICRO_BATCH_SIZE}_ga${GRAD_ACC_STEPS}_${PRECISION_TAG}_nodes_${SLURM_JOB_NUM_NODES}_${CLUSTER_NAME}_chain_${CHAIN_ID}"
| 94 | + |
# Per-run host directories, keyed by the W&B run name.
RESULTS_DIR="${BASE_DIR}/results/${WANDB_RUN_NAME}"
CKPT_DIR="${BASE_DIR}/checkpoints/${WANDB_RUN_NAME}"
mkdir -p "${RESULTS_DIR}" "${CKPT_DIR}"

# Comma-separated host_path:container_path pairs for srun --container-mounts.
printf -v MOUNTS '%s:%s/data,%s:%s/results,%s:%s/checkpoints' \
  "${DATA_DIR}" "${CODE_MOUNT}" \
  "${RESULTS_DIR}" "${CODE_MOUNT}" \
  "${CKPT_DIR}" "${CODE_MOUNT}"

# Rendezvous endpoint: first hostname of the allocation. It is resolved here on
# the host because scontrol is not available inside the container.
MASTER_PORT=29500
MASTER_ADDR=$(scontrol show hostnames "${SLURM_JOB_NODELIST}" | head -n 1)
| 106 | + |
| 107 | + |
# COMMAND holds the script that `bash -c` runs inside the container. The
# heredoc delimiter is quoted, so nothing is expanded here; every ${...} below
# is resolved by the shell that executes COMMAND. `read -d ''` returns non-zero
# at end of input, hence the `|| true`.
read -r -d '' COMMAND <<'OUTER_EOF' || true
set -euxo pipefail

echo "========================================="
echo "CodonFM ${CONFIG_NAME} - STRATEGY: ${DIST_STRATEGY} - PRECISION: ${PRECISION_TAG} - CLUSTER: ${CLUSTER_NAME}"
echo "Job ID: ${SLURM_JOB_ID}"
echo "Nodes: ${SLURM_JOB_NUM_NODES}"
echo "========================================="

export USE_XFORMERS="${USE_XFORMERS:-0}"
if [[ "${USE_XFORMERS}" == "1" ]]; then
  echo "Using Xformers"
else
  echo "Using PyTorch SDPA attention"
fi

# Optional runner flags are collected in an array so each flag and value stays
# a separate argument without relying on word splitting.
extra_args=()

# cuDNN fused-attn sub-backend 1 OOMs on Blackwell (sm_103) with THD+padding (TE 2.12 / cuDNN 9.19); force flash-attn varlen.
if [[ "${USE_SEQUENCE_PACKING}" == "True" ]]; then
  export NVTE_FUSED_ATTN=0
  extra_args+=(--collate_fn thd --attn_input_format thd)
else
  extra_args+=(--collate_fn bshd --attn_input_format bshd)
fi

# Add the strategy-specific flag; anything other than fsdp/ddp is rejected.
case "${DIST_STRATEGY}" in
  fsdp)
    extra_args+=(--enable_fsdp)
    ;;
  ddp)
    : # no additional flag
    ;;
  *)
    echo "DIST_STRATEGY must be 'fsdp' or 'ddp', got '${DIST_STRATEGY}'" >&2
    exit 1
    ;;
esac

if [[ "${PRECISION}" == "bf16-mixed" ]]; then
  extra_args+=(--bf16)
fi

if [[ "${USE_TRANSFORMER_ENGINE}" == "1" ]]; then
  extra_args+=(--use_transformer_engine)
fi

# Every expansion is quoted so an empty or space-containing value cannot drop
# or split an argument and shift the flags that follow it.
torchrun \
  --nproc_per_node="${NPROC_PER_NODE}" \
  --rdzv_id="${SLURM_JOB_ID}" \
  --rdzv_backend=c10d \
  --rdzv_endpoint="${MASTER_ADDR}:${MASTER_PORT}" \
  --nnodes="${SLURM_JOB_NUM_NODES}" \
  --node-rank="${SLURM_NODEID}" \
  -m src.runner pretrain \
  --exp_name "${WANDB_RUN_NAME}" \
  --model_name "${CONFIG_NAME}" \
  --data_path /workspace/bionemo/data/processed_unfiltered/ \
  --process_item mlm_memmap \
  --dataset_name CodonMemmapDataset \
  --lr "${LEARNING_RATE}" \
  --num_gpus "${NPROC_PER_NODE}" \
  --num_nodes "${SLURM_JOB_NUM_NODES}" \
  --train_batch_size "${MICRO_BATCH_SIZE}" \
  --val_batch_size "${MICRO_BATCH_SIZE}" \
  --num_workers "${NUM_WORKERS}" \
  "${extra_args[@]}" \
  --split_name_prefix nopathogen \
  --taxid_exclusion_file /workspace/bionemo/data/taxids_to_remove.json \
  --enable_wandb \
  --project_name "${WANDB_PROJECT}" \
  --entity clara-discovery \
  --gradient_accumulation_steps "${GRAD_ACC_STEPS}" \
  --max_steps "${NUM_TRAIN_STEPS}" \
  --log_every_n_steps "${LOGGER_FREQUENCY}"

echo "========================================="
echo "Training complete!"
echo "========================================="
OUTER_EOF
| 187 | + |
# Inject environment variables into the command.
#
# Each listed variable is prepended to COMMAND as an `export NAME=value;`
# statement, so the script run by `bash -c` starts with these values defined.

# xtrace stays off from here on: COMMAND now embeds WANDB_API_KEY and the
# Hugging Face token, and tracing any command that expands it (the assignments
# below, the srun line) would write the tokens into the job log.
{ set +x; } 2>/dev/null

#######################################
# Prepend `export NAME=<value>;` to COMMAND.
# The value is rendered with printf %q, so quotes, `$`, backticks and
# backslashes in it reach the inner shell literally instead of being
# re-interpreted as shell syntax.
# Globals:   COMMAND (read and rewritten)
# Arguments: $1 - variable name to export inside COMMAND
#            $2 - optional name of the variable to read the value from
#                 (defaults to $1)
#######################################
_inject_env() {
  local target_name=$1
  local src_name=${2:-$1}
  local quoted
  printf -v quoted '%q' "${!src_name}"
  COMMAND="export ${target_name}=${quoted}; ${COMMAND}"
}

for _name in \
  DIST_STRATEGY PRECISION_TAG CLUSTER_NAME NPROC_PER_NODE CONFIG_NAME \
  LOGGER_FREQUENCY NUM_TRAIN_STEPS GLOBAL_BATCH_SIZE MICRO_BATCH_SIZE \
  GRAD_ACC_STEPS LEARNING_RATE NUM_WORKERS USE_SEQUENCE_PACKING PRECISION \
  WANDB_RUN_NAME WANDB_PROJECT USE_XFORMERS MASTER_ADDR MASTER_PORT \
  USE_TRANSFORMER_ENGINE WANDB_API_KEY HUGGING_FACE_HUB_TOKEN; do
  _inject_env "${_name}"
done
unset _name

# HF_TOKEN carries the same value as HUGGING_FACE_HUB_TOKEN.
_inject_env HF_TOKEN HUGGING_FACE_HUB_TOKEN
| 212 | + |
| 213 | + |
echo "Launching: ${WANDB_RUN_NAME}"

#######################################
# AUTO-CHAIN: EXIT handler that resubmits this script on timeout.
# Exit status 143 (128+SIGTERM) or 137 (128+SIGKILL) is treated as a time-limit
# kill and queues the next link of the chain; status 0 and any other status end
# the chain.
# Globals:   CHAIN_ID (read; forwarded to the resubmitted job via --export)
#######################################
_on_exit() {
  local rc=$? # must be captured first: status the script is exiting with
  case "${rc}" in
    143 | 137)
      echo "Timed out (rc=${rc}) — resubmitting chain ${CHAIN_ID}."
      sbatch --dependency=singleton \
        --export=ALL,CHAIN_ID="${CHAIN_ID}" \
        "${BASH_SOURCE[0]}"
      ;;
    0)
      echo "Training finished cleanly — chain ${CHAIN_ID} ends."
      ;;
    *)
      echo "Real error (rc=${rc}) — chain ${CHAIN_ID} ends so you can investigate."
      ;;
  esac
}
trap _on_exit EXIT

# COMMAND embeds WANDB_API_KEY and the Hugging Face token; tracing the srun
# line would print the whole string, tokens included, into the job log, so
# xtrace is switched off before launching.
{ set +x; } 2>/dev/null

srun \
  --output "${RESULTS_DIR}/slurm-%j-%n.out" \
  --error "${RESULTS_DIR}/error-%j-%n.out" \
  --container-image "${CONTAINER}" \
  --container-mounts "${MOUNTS}" \
  bash -c "${COMMAND}"
0 commit comments