1111# SBATCH --exclusive
1212set -euxo pipefail
1313
# Establish or inherit the chain ID.
#
# A manual launch has no CHAIN_ID in its environment, so the current
# SLURM_JOB_ID becomes the ID of a new chain. A job resubmitted by the EXIT
# trap receives CHAIN_ID through `sbatch --export` and keeps it unchanged.
#
# Globals:  CHAIN_ID (read; written and exported when empty), SLURM_JOB_ID (read)
# Outputs:  one status line on stdout
establish_chain_id() {
  # ":-" keeps the emptiness test safe under `set -u` when CHAIN_ID is unset.
  if [[ -z "${CHAIN_ID:-}" ]]; then
    export CHAIN_ID="${SLURM_JOB_ID}"
    echo "Starting NEW chain: CHAIN_ID=${CHAIN_ID}"
  else
    echo "Continuing chain ${CHAIN_ID} (current job ${SLURM_JOB_ID})"
  fi
}

establish_chain_id
21+
1422# ============================================================================
1523# Codon 1B
1624# ============================================================================
@@ -26,15 +34,15 @@ CODE_MOUNT="/workspace/bionemo"
2634: " ${CLUSTER_NAME:? Set CLUSTER_NAME in ~/ .bash_profile} "
2735
# Batch sizing. GRAD_ACC_STEPS is derived further down from
# GLOBAL_BATCH_SIZE, so these two values must stay mutually consistent.
export GLOBAL_BATCH_SIZE=1536
export MICRO_BATCH_SIZE=96

# Experiment parameters
export CONFIG_NAME=encodon_1b
export NPROC_PER_NODE=8
export DIST_STRATEGY=ddp # fsdp or ddp

# Training
export NUM_TRAIN_STEPS=100
export LEARNING_RATE=7.5e-5
export NUM_WORKERS=1
export USE_SEQUENCE_PACKING=False
# Gradient accumulation is the integer ratio of the global batch to
# TOTAL_PER_STEP (computed earlier in the script). Integer division would
# silently truncate, or yield 0 when TOTAL_PER_STEP exceeds the global batch,
# so refuse to launch with an inconsistent configuration.
if (( TOTAL_PER_STEP <= 0 || GLOBAL_BATCH_SIZE % TOTAL_PER_STEP != 0 )); then
  echo "ERROR: GLOBAL_BATCH_SIZE=${GLOBAL_BATCH_SIZE} must be a positive multiple of TOTAL_PER_STEP=${TOTAL_PER_STEP}" >&2
  exit 1
fi
export GRAD_ACC_STEPS=$(( GLOBAL_BATCH_SIZE / TOTAL_PER_STEP ))
echo "Batch sizing: GBS=${GLOBAL_BATCH_SIZE}, MBS=${MICRO_BATCH_SIZE}, NPROC=${NPROC_PER_NODE}, NODES=${SLURM_JOB_NUM_NODES}, GRAD_ACC=${GRAD_ACC_STEPS}"

# The run name ends with the chain ID, so every job of one resubmission chain
# resolves to the same name (and therefore the same RESULTS_DIR below).
export WANDB_RUN_NAME="${MODEL_SIZE}_${DIST_STRATEGY}_${BATCH_TYPE_TAG}_gbs${GLOBAL_BATCH_SIZE}_mbs${MICRO_BATCH_SIZE}_ga${GRAD_ACC_STEPS}_${PRECISION_TAG}_nodes_${SLURM_JOB_NUM_NODES}_${CLUSTER_NAME}_chain_${CHAIN_ID}"

# Mounts
RESULTS_DIR="${BASE_DIR}/results/${WANDB_RUN_NAME}"
@@ -99,6 +107,10 @@ mkdir -p "${RESULTS_DIR}" "${CKPT_DIR}"
99107
100108MOUNTS=" ${DATA_DIR} :${CODE_MOUNT} /data,${RESULTS_DIR} :${CODE_MOUNT} /results,${CKPT_DIR} :${CODE_MOUNT} /checkpoints"
101109
# Resolve head node on the host (scontrol is not available inside the container).
# The first hostname of the allocation is used as the rendezvous host.
MASTER_ADDR=$(scontrol show hostnames "${SLURM_JOB_NODELIST}" | head -n 1)
# An empty result would otherwise yield the endpoint ":29500"; abort here
# with a clear message instead.
: "${MASTER_ADDR:?Could not resolve head node from SLURM_JOB_NODELIST}"
MASTER_PORT=29500
113+
102114
103115read -r -d ' ' COMMAND << 'OUTER_EOF ' || true
104116set -euxo pipefail
@@ -123,7 +135,14 @@ case "${DIST_STRATEGY}" in
123135 ;;
124136esac
125137
126- torchrun --nproc_per_node=${NPROC_PER_NODE} ${TRAIN_SCRIPT} \
138+ torchrun \
139+ --nproc_per_node=${NPROC_PER_NODE} \
140+ --rdzv_id=${SLURM_JOB_ID} \
141+ --rdzv_backend=c10d \
142+ --rdzv_endpoint=${MASTER_ADDR}:${MASTER_PORT} \
143+ --nnodes=${SLURM_JOB_NUM_NODES} \
144+ --node-rank=${SLURM_NODEID} \
145+ ${TRAIN_SCRIPT} \
127146 --config-name ${CONFIG_NAME} \
128147 quant_stats_config.enabled=${QUANT_STATS_ENABLED} \
129148 logger.frequency=${LOGGER_FREQUENCY} \
@@ -182,6 +201,8 @@ COMMAND="export HYDRA_RUN_DIR=\"${HYDRA_RUN_DIR}\"; ${COMMAND}"
# Prepend `export NAME="<current value of NAME>"; ` to COMMAND so the variable
# is defined when COMMAND is later executed.
# Globals:   COMMAND (read and rewritten)
# Arguments: $1 - name of the variable to forward
prepend_export() {
  local name=$1
  COMMAND="export ${name}=\"${!name}\"; ${COMMAND}"
}

prepend_export FP8_ENABLED
prepend_export FP8_RECIPE
prepend_export FP8_FORMAT
prepend_export MASTER_ADDR
prepend_export MASTER_PORT

# Credentials: the script runs with `set -x`, which would print the expanded
# secrets into the job log. Suspend tracing while they are spliced in, then
# restore the previous state.
# NOTE(review): COMMAND now embeds the secrets; any later traced command that
# expands ${COMMAND} will still print them - confirm how COMMAND is launched.
case "$-" in
  *x*) xtrace_was_on=1 ;;
  *) xtrace_was_on=0 ;;
esac
set +x
prepend_export WANDB_API_KEY
prepend_export HUGGING_FACE_HUB_TOKEN
if (( xtrace_was_on )); then
  set -x
fi
unset xtrace_was_on
@@ -193,12 +214,14 @@ echo "Launching: ${WANDB_RUN_NAME}"
# EXIT handler that drives the resubmission chain.
#
# Arguments: $1 - exit status the script is terminating with
# Globals:   CHAIN_ID (read)
#
#   143 / 137  killed by SIGTERM / SIGKILL; treated as a SLURM timeout, so the
#              same script is resubmitted with the chain ID carried over.
#              NOTE(review): 137 can also come from an OOM kill, in which case
#              resubmitting would loop - confirm this is acceptable.
#   0          training finished; the chain ends.
#   other      genuine failure; the chain ends so it can be investigated.
handle_exit() {
  local rc=$1
  case "${rc}" in
    137 | 143)
      echo "Timed out (rc=$rc) — resubmitting chain ${CHAIN_ID}."
      # singleton: the new job waits until no other job with the same
      # name and user is running.
      if ! sbatch --dependency=singleton \
        --export=ALL,CHAIN_ID="${CHAIN_ID}" \
        "${BASH_SOURCE[0]}"; then
        echo "Resubmission FAILED — chain ${CHAIN_ID} is broken; resubmit manually." >&2
      fi
      ;;
    0)
      echo "Training finished cleanly — chain ${CHAIN_ID} ends."
      ;;
    *)
      echo "Real error (rc=$rc) — chain ${CHAIN_ID} ends so you can investigate."
      ;;
  esac
}

trap 'handle_exit "$?"' EXIT
204227
0 commit comments