|
| 1 | +#!/bin/bash |
| 2 | +#SBATCH --account= |
| 3 | +#SBATCH --nodes=1 |
| 4 | +#SBATCH --partition= |
| 5 | +#SBATCH --ntasks-per-node=1 |
| 6 | +#SBATCH --time=03:55:00 |
| 7 | +#SBATCH --mem=0 |
| 8 | +#SBATCH --job-name= |
| 9 | +#SBATCH --mail-type=FAIL |
| 10 | +#SBATCH --overcommit |
| 11 | +#SBATCH --exclusive |
| 12 | +set -euxo pipefail |
| 13 | + |
| 14 | +# ============================================================================ |
| 15 | +# Codon 1B |
| 16 | +# ============================================================================ |
| 17 | + |
| 18 | +BASE_DIR="" |
| 19 | +CONTAINER="" |
| 20 | +DATA_DIR="${BASE_DIR}/data" |
| 21 | +CODE_MOUNT="/workspace/bionemo" |
| 22 | + |
| 23 | + |
| 24 | +: "${WANDB_API_KEY:?Set WANDB_API_KEY in ~/.bash_profile}" |
| 25 | +: "${HUGGING_FACE_HUB_TOKEN:?Set HUGGING_FACE_HUB_TOKEN in ~/.bash_profile}" |
| 26 | +: "${CLUSTER_NAME:?Set CLUSTER_NAME in ~/.bash_profile}" |
| 27 | + |
| 28 | +# Experiment parameters |
| 29 | +export CONFIG_NAME=encodon_1b |
| 30 | +export NPROC_PER_NODE=8 |
| 31 | +export DIST_STRATEGY=ddp # fsdp or ddp |
| 32 | + |
| 33 | +# Training |
| 34 | +export NUM_TRAIN_STEPS=1000 |
| 35 | +export MICRO_BATCH_SIZE=31 |
| 36 | +export LEARNING_RATE=7.5e-5 |
| 37 | +export NUM_WORKERS=1 |
| 38 | +export USE_SEQUENCE_PACKING=False |
| 39 | +# Precision mode: one of fp32, bf16, bf16-mixed. bf16-mixed matches the reference codonfm `--bf16`. |
| 40 | +export PRECISION=bf16-mixed |
| 41 | +# Only used for FSDP2 + bf16-mixed. One of fp32, bf16. |
| 42 | +export GRAD_REDUCE_TYPE=fp32 |
| 43 | +export NUM_WARMUP_STEPS=50 |
| 44 | + |
| 45 | +# Logging / W&B |
| 46 | +export LOGGER_FREQUENCY=10 |
| 47 | +export WANDB_PROJECT= |
| 48 | + |
| 49 | +# Checkpointing |
| 50 | +export SAVE_FINAL_MODEL=True |
| 51 | +export SAVE_EVERY_N_STEPS=100000 |
| 52 | +export RESUME_FROM_CHECKPOINT=True |
| 53 | + |
| 54 | +# Hydra |
| 55 | +export HYDRA_RUN_DIR=1b_test |
| 56 | + |
| 57 | +# Quantization / FP8 |
| 58 | +export QUANT_STATS_ENABLED=False |
| 59 | +export FP8_ENABLED=False |
| 60 | +export FP8_RECIPE=transformer_engine.common.recipe.MXFP8BlockScaling |
| 61 | +export FP8_FORMAT=E4M3 |
| 62 | + |
| 63 | +# Derived: build wandb run name from model size, batch size, and precision recipe |
| 64 | +MODEL_SIZE="${CONFIG_NAME##*_}" |
| 65 | +if [ "${FP8_ENABLED}" = "True" ]; then |
| 66 | + RECIPE_SHORT="${FP8_RECIPE##*.}" |
| 67 | + RECIPE_SHORT="${RECIPE_SHORT%BlockScaling}" |
| 68 | + RECIPE_SHORT="${RECIPE_SHORT%Scaling}" |
| 69 | + PRECISION_TAG="${PRECISION}_${RECIPE_SHORT,,}_${FP8_FORMAT,,}" |
| 70 | +else |
| 71 | + PRECISION_TAG="${PRECISION}" |
| 72 | +fi |
| 73 | + |
| 74 | +if [ "${USE_SEQUENCE_PACKING}" = "True" ]; then |
| 75 | + BATCH_TYPE_TAG="thd" |
| 76 | +else |
| 77 | + BATCH_TYPE_TAG="bshd" |
| 78 | +fi |
| 79 | + |
| 80 | +export WANDB_RUN_NAME="${MODEL_SIZE}_${DIST_STRATEGY}_${BATCH_TYPE_TAG}_bs${MICRO_BATCH_SIZE}_${PRECISION_TAG}_nodes_${SLURM_JOB_NUM_NODES}_${CLUSTER_NAME}" |
| 81 | + |
| 82 | +# Mounts |
| 83 | +RESULTS_DIR="${BASE_DIR}/results/${WANDB_RUN_NAME}" |
| 84 | +CKPT_DIR="${BASE_DIR}/checkpoints/${WANDB_RUN_NAME}" |
| 85 | + |
| 86 | +mkdir -p "${RESULTS_DIR}" "${CKPT_DIR}" |
| 87 | + |
| 88 | +MOUNTS="${DATA_DIR}:${CODE_MOUNT}/data,${RESULTS_DIR}:${CODE_MOUNT}/results,${CKPT_DIR}:${CODE_MOUNT}/checkpoints" |
| 89 | + |
| 90 | + |
| 91 | +read -r -d '' COMMAND <<'OUTER_EOF' || true |
| 92 | +set -euxo pipefail |
| 93 | +
|
| 94 | +echo "=========================================" |
| 95 | +echo "CodonFM ${CONFIG_NAME} - STRATEGY: ${DIST_STRATEGY} - PRECISION: ${PRECISION_TAG} - CLUSTER: ${CLUSTER_NAME}" |
| 96 | +echo "Job ID: ${SLURM_JOB_ID}" |
| 97 | +echo "Nodes: ${SLURM_JOB_NUM_NODES}" |
| 98 | +echo "=========================================" |
| 99 | +
|
| 100 | +# Pick training script based on distributed strategy. |
| 101 | +case "${DIST_STRATEGY}" in |
| 102 | + fsdp) |
| 103 | + TRAIN_SCRIPT=train_fsdp2.py |
| 104 | + ;; |
| 105 | + ddp) |
| 106 | + TRAIN_SCRIPT=train_ddp.py |
| 107 | + ;; |
| 108 | + *) |
| 109 | + echo "DIST_STRATEGY must be 'fsdp' or 'ddp', got '${DIST_STRATEGY}'" >&2 |
| 110 | + exit 1 |
| 111 | + ;; |
| 112 | +esac |
| 113 | +
|
| 114 | +torchrun --nproc_per_node=${NPROC_PER_NODE} ${TRAIN_SCRIPT} \ |
| 115 | + --config-name ${CONFIG_NAME} \ |
| 116 | + quant_stats_config.enabled=${QUANT_STATS_ENABLED} \ |
| 117 | + logger.frequency=${LOGGER_FREQUENCY} \ |
| 118 | + num_train_steps=${NUM_TRAIN_STEPS} \ |
| 119 | + dataset.micro_batch_size=${MICRO_BATCH_SIZE} \ |
| 120 | + adamw_kwargs.lr=${LEARNING_RATE} \ |
| 121 | + dataset.num_workers=${NUM_WORKERS} \ |
| 122 | + dataset.data_path=/workspace/bionemo/data/processed_unfiltered/ \ |
| 123 | + use_sequence_packing=${USE_SEQUENCE_PACKING} \ |
| 124 | + precision=${PRECISION} \ |
| 125 | + grad_reduce_type=${GRAD_REDUCE_TYPE} \ |
| 126 | + lr_scheduler_kwargs.num_warmup_steps=${NUM_WARMUP_STEPS} \ |
| 127 | + wandb_init_args.name=${WANDB_RUN_NAME} \ |
| 128 | + +wandb_init_args.id=${WANDB_RUN_NAME} \ |
| 129 | + +wandb_init_args.project=${WANDB_PROJECT} \ |
| 130 | + checkpoint.save_final_model=${SAVE_FINAL_MODEL} \ |
| 131 | + checkpoint.save_every_n_steps=${SAVE_EVERY_N_STEPS} \ |
| 132 | + checkpoint.ckpt_dir=/workspace/bionemo/checkpoints \ |
| 133 | + checkpoint.resume_from_checkpoint=${RESUME_FROM_CHECKPOINT} \ |
| 134 | + hydra.run.dir=${HYDRA_RUN_DIR} \ |
| 135 | + fp8_config.enabled=${FP8_ENABLED} \ |
| 136 | + fp8_config.fp8_recipe=${FP8_RECIPE} \ |
| 137 | + fp8_config.fp8_format=${FP8_FORMAT} \ |
| 138 | + +dataset.pad_to_multiple_of=32 |
| 139 | +
|
| 140 | +echo "=========================================" |
| 141 | +echo "Training complete!" |
| 142 | +echo "=========================================" |
| 143 | +OUTER_EOF |
| 144 | + |
| 145 | +# Inject environment variables into the command. |
| 146 | +COMMAND="export DIST_STRATEGY=\"${DIST_STRATEGY}\"; ${COMMAND}" |
| 147 | +COMMAND="export PRECISION_TAG=\"${PRECISION_TAG}\"; ${COMMAND}" |
| 148 | +COMMAND="export CLUSTER_NAME=\"${CLUSTER_NAME}\"; ${COMMAND}" |
| 149 | +COMMAND="export NPROC_PER_NODE=\"${NPROC_PER_NODE}\"; ${COMMAND}" |
| 150 | +COMMAND="export CONFIG_NAME=\"${CONFIG_NAME}\"; ${COMMAND}" |
| 151 | +COMMAND="export QUANT_STATS_ENABLED=\"${QUANT_STATS_ENABLED}\"; ${COMMAND}" |
| 152 | +COMMAND="export LOGGER_FREQUENCY=\"${LOGGER_FREQUENCY}\"; ${COMMAND}" |
| 153 | +COMMAND="export NUM_TRAIN_STEPS=\"${NUM_TRAIN_STEPS}\"; ${COMMAND}" |
| 154 | +COMMAND="export MICRO_BATCH_SIZE=\"${MICRO_BATCH_SIZE}\"; ${COMMAND}" |
| 155 | +COMMAND="export LEARNING_RATE=\"${LEARNING_RATE}\"; ${COMMAND}" |
| 156 | +COMMAND="export NUM_WORKERS=\"${NUM_WORKERS}\"; ${COMMAND}" |
| 157 | +COMMAND="export USE_SEQUENCE_PACKING=\"${USE_SEQUENCE_PACKING}\"; ${COMMAND}" |
| 158 | +COMMAND="export PRECISION=\"${PRECISION}\"; ${COMMAND}" |
| 159 | +COMMAND="export GRAD_REDUCE_TYPE=\"${GRAD_REDUCE_TYPE}\"; ${COMMAND}" |
| 160 | +COMMAND="export NUM_WARMUP_STEPS=\"${NUM_WARMUP_STEPS}\"; ${COMMAND}" |
| 161 | +COMMAND="export WANDB_RUN_NAME=\"${WANDB_RUN_NAME}\"; ${COMMAND}" |
| 162 | +COMMAND="export WANDB_PROJECT=\"${WANDB_PROJECT}\"; ${COMMAND}" |
| 163 | +COMMAND="export SAVE_FINAL_MODEL=\"${SAVE_FINAL_MODEL}\"; ${COMMAND}" |
| 164 | +COMMAND="export SAVE_EVERY_N_STEPS=\"${SAVE_EVERY_N_STEPS}\"; ${COMMAND}" |
| 165 | +COMMAND="export RESUME_FROM_CHECKPOINT=\"${RESUME_FROM_CHECKPOINT}\"; ${COMMAND}" |
| 166 | +COMMAND="export HYDRA_RUN_DIR=\"${HYDRA_RUN_DIR}\"; ${COMMAND}" |
| 167 | +COMMAND="export FP8_ENABLED=\"${FP8_ENABLED}\"; ${COMMAND}" |
| 168 | +COMMAND="export FP8_RECIPE=\"${FP8_RECIPE}\"; ${COMMAND}" |
| 169 | +COMMAND="export FP8_FORMAT=\"${FP8_FORMAT}\"; ${COMMAND}" |
| 170 | + |
| 171 | +COMMAND="export WANDB_API_KEY=\"${WANDB_API_KEY}\"; ${COMMAND}" |
| 172 | +COMMAND="export HUGGING_FACE_HUB_TOKEN=\"${HUGGING_FACE_HUB_TOKEN}\"; ${COMMAND}" |
| 173 | +COMMAND="export HF_TOKEN=\"${HUGGING_FACE_HUB_TOKEN}\"; ${COMMAND}" |
| 174 | + |
| 175 | +echo "Launching: ${WANDB_RUN_NAME}" |
| 176 | + |
| 177 | +# AUTO-CHAIN: resubmit on timeout. |
| 178 | +trap ' |
| 179 | + rc=$? |
| 180 | + if [ "$rc" -eq 143 ] || [ "$rc" -eq 137 ]; then |
| 181 | + echo "Killed by signal (rc=$rc) — assuming SLURM timeout, resubmitting..." |
| 182 | + sbatch --dependency=singleton "${BASH_SOURCE[0]}" |
| 183 | + elif [ "$rc" -eq 0 ]; then |
| 184 | + echo "Clean exit — training finished, NOT resubmitting." |
| 185 | + else |
| 186 | + echo "Error exit (rc=$rc) — NOT resubmitting; investigate ${RESULTS_DIR}" |
| 187 | + fi |
| 188 | + ' EXIT |
| 189 | + |
| 190 | +srun \ |
| 191 | + --output "${RESULTS_DIR}/slurm-%j-%n.out" \ |
| 192 | + --error "${RESULTS_DIR}/error-%j-%n.out" \ |
| 193 | + --container-image "${CONTAINER}" \ |
| 194 | + --container-mounts "${MOUNTS}" \ |
| 195 | + bash -c "${COMMAND}" |
0 commit comments