Skip to content

Commit b53025e

Browse files
committed
make script run in multi-node
1 parent b1126aa commit b53025e

1 file changed

Lines changed: 31 additions & 8 deletions

File tree

  • bionemo-recipes/recipes/codonfm_native_te/slurm

bionemo-recipes/recipes/codonfm_native_te/slurm/1b.sh

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,14 @@
1111
#SBATCH --exclusive
1212
set -euxo pipefail
1313

14+
# Establish or inherit chain ID: manual launch picks SLURM_JOB_ID; trap-resubmit inherits via --export.
15+
if [ -z "${CHAIN_ID:-}" ]; then
16+
export CHAIN_ID="${SLURM_JOB_ID}"
17+
echo "Starting NEW chain: CHAIN_ID=${CHAIN_ID}"
18+
else
19+
echo "Continuing chain ${CHAIN_ID} (current job ${SLURM_JOB_ID})"
20+
fi
21+
1422
# ============================================================================
1523
# Codon 1B
1624
# ============================================================================
@@ -26,15 +34,15 @@ CODE_MOUNT="/workspace/bionemo"
2634
: "${CLUSTER_NAME:?Set CLUSTER_NAME in ~/.bash_profile}"
2735

2836
export GLOBAL_BATCH_SIZE=1536
29-
export MICRO_BATCH_SIZE=4
37+
export MICRO_BATCH_SIZE=96
3038

3139
# Experiment parameters
3240
export CONFIG_NAME=encodon_1b
3341
export NPROC_PER_NODE=8
3442
export DIST_STRATEGY=ddp # fsdp or ddp
3543

3644
# Training
37-
export NUM_TRAIN_STEPS=1000
45+
export NUM_TRAIN_STEPS=100
3846
export LEARNING_RATE=7.5e-5
3947
export NUM_WORKERS=1
4048
export USE_SEQUENCE_PACKING=False
@@ -89,7 +97,7 @@ fi
8997
export GRAD_ACC_STEPS=$(( GLOBAL_BATCH_SIZE / TOTAL_PER_STEP ))
9098
echo "Batch sizing: GBS=${GLOBAL_BATCH_SIZE}, MBS=${MICRO_BATCH_SIZE}, NPROC=${NPROC_PER_NODE}, NODES=${SLURM_JOB_NUM_NODES}, GRAD_ACC=${GRAD_ACC_STEPS}"
9199

92-
export WANDB_RUN_NAME="${MODEL_SIZE}_${DIST_STRATEGY}_${BATCH_TYPE_TAG}_gbs${GLOBAL_BATCH_SIZE}_mbs${MICRO_BATCH_SIZE}_ga${GRAD_ACC_STEPS}_${PRECISION_TAG}_nodes_${SLURM_JOB_NUM_NODES}_${CLUSTER_NAME}"
100+
export WANDB_RUN_NAME="${MODEL_SIZE}_${DIST_STRATEGY}_${BATCH_TYPE_TAG}_gbs${GLOBAL_BATCH_SIZE}_mbs${MICRO_BATCH_SIZE}_ga${GRAD_ACC_STEPS}_${PRECISION_TAG}_nodes_${SLURM_JOB_NUM_NODES}_${CLUSTER_NAME}_chain_${CHAIN_ID}"
93101

94102
# Mounts
95103
RESULTS_DIR="${BASE_DIR}/results/${WANDB_RUN_NAME}"
@@ -99,6 +107,10 @@ mkdir -p "${RESULTS_DIR}" "${CKPT_DIR}"
99107

100108
MOUNTS="${DATA_DIR}:${CODE_MOUNT}/data,${RESULTS_DIR}:${CODE_MOUNT}/results,${CKPT_DIR}:${CODE_MOUNT}/checkpoints"
101109

110+
# Resolve head node on the host (scontrol is not available inside the container).
111+
MASTER_ADDR=$(scontrol show hostnames "${SLURM_JOB_NODELIST}" | head -n 1)
112+
MASTER_PORT=29500
113+
102114

103115
read -r -d '' COMMAND <<'OUTER_EOF' || true
104116
set -euxo pipefail
@@ -123,7 +135,14 @@ case "${DIST_STRATEGY}" in
123135
;;
124136
esac
125137
126-
torchrun --nproc_per_node=${NPROC_PER_NODE} ${TRAIN_SCRIPT} \
138+
torchrun \
139+
--nproc_per_node=${NPROC_PER_NODE} \
140+
--rdzv_id=${SLURM_JOB_ID} \
141+
--rdzv_backend=c10d \
142+
--rdzv_endpoint=${MASTER_ADDR}:${MASTER_PORT} \
143+
--nnodes=${SLURM_JOB_NUM_NODES} \
144+
--node-rank=${SLURM_NODEID} \
145+
${TRAIN_SCRIPT} \
127146
--config-name ${CONFIG_NAME} \
128147
quant_stats_config.enabled=${QUANT_STATS_ENABLED} \
129148
logger.frequency=${LOGGER_FREQUENCY} \
@@ -182,6 +201,8 @@ COMMAND="export HYDRA_RUN_DIR=\"${HYDRA_RUN_DIR}\"; ${COMMAND}"
182201
COMMAND="export FP8_ENABLED=\"${FP8_ENABLED}\"; ${COMMAND}"
183202
COMMAND="export FP8_RECIPE=\"${FP8_RECIPE}\"; ${COMMAND}"
184203
COMMAND="export FP8_FORMAT=\"${FP8_FORMAT}\"; ${COMMAND}"
204+
COMMAND="export MASTER_ADDR=\"${MASTER_ADDR}\"; ${COMMAND}"
205+
COMMAND="export MASTER_PORT=\"${MASTER_PORT}\"; ${COMMAND}"
185206

186207
COMMAND="export WANDB_API_KEY=\"${WANDB_API_KEY}\"; ${COMMAND}"
187208
COMMAND="export HUGGING_FACE_HUB_TOKEN=\"${HUGGING_FACE_HUB_TOKEN}\"; ${COMMAND}"
@@ -193,12 +214,14 @@ echo "Launching: ${WANDB_RUN_NAME}"
193214
trap '
194215
rc=$?
195216
if [ "$rc" -eq 143 ] || [ "$rc" -eq 137 ]; then
196-
echo "Killed by signal (rc=$rc) — assuming SLURM timeout, resubmitting..."
197-
sbatch --dependency=singleton "${BASH_SOURCE[0]}"
217+
echo "Timed out (rc=$rc) — resubmitting chain ${CHAIN_ID}."
218+
sbatch --dependency=singleton \
219+
--export=ALL,CHAIN_ID="${CHAIN_ID}" \
220+
"${BASH_SOURCE[0]}"
198221
elif [ "$rc" -eq 0 ]; then
199-
echo "Clean exit — training finished, NOT resubmitting."
222+
echo "Training finished cleanly — chain ${CHAIN_ID} ends."
200223
else
201-
echo "Error exit (rc=$rc) — NOT resubmitting; investigate ${RESULTS_DIR}"
224+
echo "Real error (rc=$rc) — chain ${CHAIN_ID} ends so you can investigate."
202225
fi
203226
' EXIT
204227

0 commit comments

Comments
 (0)