Skip to content

Commit aab4267

Browse files
committed
Added support for different precision modes
1 parent 33dd223 commit aab4267

9 files changed

Lines changed: 325 additions & 30 deletions

File tree

bionemo-recipes/recipes/codonfm_native_te/hydra_config/L0_sanity.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ defaults:
66
model_preset: encodon_200k
77
num_train_steps: 250
88

9+
precision: fp32
10+
911
use_sequence_packing: false
1012
dataset:
1113
data_path: train.parquet

bionemo-recipes/recipes/codonfm_native_te/hydra_config/defaults.yaml

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,4 +81,15 @@ quant_stats_config:
8181
# Note: The layers are going to come in 1 indexed and we convert them to be 0 indexed at runtime.
8282
fp8_layers: null
8383
fp4_layers: null
84-
use_fp32_master_weights: null
84+
85+
# Precision mode. One of:
86+
# fp32 - params, compute, grads, and optimizer state all in fp32.
87+
# bf16 - params, compute, grads, and optimizer state all in bf16 (pure bf16).
88+
# bf16-mixed - fp32 master weights + bf16 compute (via autocast in DDP, via FSDP2
89+
# MixedPrecisionPolicy.param_dtype=bf16 in FSDP2).
90+
precision: ???
91+
92+
# Gradient reduce dtype for FSDP2 when precision=bf16-mixed. One of: fp32, bf16.
93+
# fp32 (default) is more conservative than PTL FSDP bf16-mixed (which reduces in bf16).
94+
# Ignored for other precision modes and for DDP.
95+
grad_reduce_type: fp32

bionemo-recipes/recipes/codonfm_native_te/hydra_config/encodon_1b.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ defaults:
66
model_preset: encodon_1b
77
num_train_steps: 500_000
88

9+
precision: bf16-mixed
10+
911
use_sequence_packing: true
1012
dataset:
1113
data_path: ???

bionemo-recipes/recipes/codonfm_native_te/hydra_config/encodon_5b.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ defaults:
66
model_preset: encodon_5b
77
num_train_steps: 500_000
88

9+
precision: bf16-mixed
10+
911
use_sequence_packing: true
1012
dataset:
1113
data_path: ???

bionemo-recipes/recipes/codonfm_native_te/run_1b.sh

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,10 @@ export NUM_TRAIN_STEPS=100
1414
export MICRO_BATCH_SIZE=31
1515
export NUM_WORKERS=1
1616
export USE_SEQUENCE_PACKING=True
17-
export USE_FP32_MASTER_WEIGHTS=True
17+
# Precision mode: one of fp32, bf16, bf16-mixed. bf16-mixed matches the reference codonfm `--bf16`.
18+
export PRECISION=bf16-mixed
19+
# Only used for FSDP2 + bf16-mixed. One of fp32, bf16.
20+
export GRAD_REDUCE_TYPE=fp32
1821
export NUM_WARMUP_STEPS=500
1922

2023
# Logging / W&B
@@ -46,24 +49,19 @@ if [ "${FP8_ENABLED}" = "True" ]; then
4649
RECIPE_SHORT="${FP8_RECIPE##*.}"
4750
RECIPE_SHORT="${RECIPE_SHORT%BlockScaling}"
4851
RECIPE_SHORT="${RECIPE_SHORT%Scaling}"
49-
PRECISION_TAG="${RECIPE_SHORT,,}_${FP8_FORMAT,,}"
52+
PRECISION_TAG="${PRECISION}_${RECIPE_SHORT,,}_${FP8_FORMAT,,}"
5053
else
51-
PRECISION_TAG="bf16"
54+
PRECISION_TAG="${PRECISION}"
5255
fi
5356
export WANDB_RUN_NAME="${MODEL_SIZE}_${DIST_STRATEGY}_bs${MICRO_BATCH_SIZE}_${PRECISION_TAG}"
5457

5558
# Pick training script based on distributed strategy.
56-
# DDP can't emulate FSDP's fp32-master / bf16-param split, so force fp32 master weights off.
5759
case "${DIST_STRATEGY}" in
5860
fsdp)
5961
TRAIN_SCRIPT=train_fsdp2.py
6062
;;
6163
ddp)
6264
TRAIN_SCRIPT=train_ddp.py
63-
if [ "${USE_FP32_MASTER_WEIGHTS}" = "True" ]; then
64-
echo "DIST_STRATEGY=ddp: overriding USE_FP32_MASTER_WEIGHTS=True -> False" >&2
65-
export USE_FP32_MASTER_WEIGHTS=False
66-
fi
6765
;;
6866
*)
6967
echo "DIST_STRATEGY must be 'fsdp' or 'ddp', got '${DIST_STRATEGY}'" >&2
@@ -80,7 +78,8 @@ torchrun --nproc_per_node=${NPROC_PER_NODE} ${TRAIN_SCRIPT} \
8078
dataset.num_workers=${NUM_WORKERS} \
8179
dataset.data_path=${DATASET_DATA_PATH} \
8280
use_sequence_packing=${USE_SEQUENCE_PACKING} \
83-
use_fp32_master_weights=${USE_FP32_MASTER_WEIGHTS} \
81+
precision=${PRECISION} \
82+
grad_reduce_type=${GRAD_REDUCE_TYPE} \
8483
lr_scheduler_kwargs.num_warmup_steps=${NUM_WARMUP_STEPS} \
8584
wandb_init_args.name=${WANDB_RUN_NAME} \
8685
wandb_init_args.project=${WANDB_PROJECT} \
Lines changed: 195 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,195 @@
1+
#!/bin/bash
2+
#SBATCH --account=
3+
#SBATCH --nodes=1
4+
#SBATCH --partition=
5+
#SBATCH --ntasks-per-node=1
6+
#SBATCH --time=03:55:00
7+
#SBATCH --mem=0
8+
#SBATCH --job-name=
9+
#SBATCH --mail-type=FAIL
10+
#SBATCH --overcommit
11+
#SBATCH --exclusive
12+
set -euxo pipefail
13+
14+
# ============================================================================
15+
# Codon 1B
16+
# ============================================================================
17+
18+
BASE_DIR=""
19+
CONTAINER=""
20+
DATA_DIR="${BASE_DIR}/data"
21+
CODE_MOUNT="/workspace/bionemo"
22+
23+
24+
: "${WANDB_API_KEY:?Set WANDB_API_KEY in ~/.bash_profile}"
25+
: "${HUGGING_FACE_HUB_TOKEN:?Set HUGGING_FACE_HUB_TOKEN in ~/.bash_profile}"
26+
: "${CLUSTER_NAME:?Set CLUSTER_NAME in ~/.bash_profile}"
27+
28+
# Experiment parameters
29+
export CONFIG_NAME=encodon_1b
30+
export NPROC_PER_NODE=8
31+
export DIST_STRATEGY=ddp # fsdp or ddp
32+
33+
# Training
34+
export NUM_TRAIN_STEPS=1000
35+
export MICRO_BATCH_SIZE=31
36+
export LEARNING_RATE=7.5e-5
37+
export NUM_WORKERS=1
38+
export USE_SEQUENCE_PACKING=False
39+
# Precision mode: one of fp32, bf16, bf16-mixed. bf16-mixed matches the reference codonfm `--bf16`.
40+
export PRECISION=bf16-mixed
41+
# Only used for FSDP2 + bf16-mixed. One of fp32, bf16.
42+
export GRAD_REDUCE_TYPE=fp32
43+
export NUM_WARMUP_STEPS=50
44+
45+
# Logging / W&B
46+
export LOGGER_FREQUENCY=10
47+
export WANDB_PROJECT=
48+
49+
# Checkpointing
50+
export SAVE_FINAL_MODEL=True
51+
export SAVE_EVERY_N_STEPS=100000
52+
export RESUME_FROM_CHECKPOINT=True
53+
54+
# Hydra
55+
export HYDRA_RUN_DIR=1b_test
56+
57+
# Quantization / FP8
58+
export QUANT_STATS_ENABLED=False
59+
export FP8_ENABLED=False
60+
export FP8_RECIPE=transformer_engine.common.recipe.MXFP8BlockScaling
61+
export FP8_FORMAT=E4M3
62+
63+
# Derived: build wandb run name from model size, batch size, and precision recipe
64+
MODEL_SIZE="${CONFIG_NAME##*_}"
65+
if [ "${FP8_ENABLED}" = "True" ]; then
66+
RECIPE_SHORT="${FP8_RECIPE##*.}"
67+
RECIPE_SHORT="${RECIPE_SHORT%BlockScaling}"
68+
RECIPE_SHORT="${RECIPE_SHORT%Scaling}"
69+
PRECISION_TAG="${PRECISION}_${RECIPE_SHORT,,}_${FP8_FORMAT,,}"
70+
else
71+
PRECISION_TAG="${PRECISION}"
72+
fi
73+
74+
if [ "${USE_SEQUENCE_PACKING}" = "True" ]; then
75+
BATCH_TYPE_TAG="thd"
76+
else
77+
BATCH_TYPE_TAG="bshd"
78+
fi
79+
80+
export WANDB_RUN_NAME="${MODEL_SIZE}_${DIST_STRATEGY}_${BATCH_TYPE_TAG}_bs${MICRO_BATCH_SIZE}_${PRECISION_TAG}_nodes_${SLURM_JOB_NUM_NODES}_${CLUSTER_NAME}"
81+
82+
# Mounts
83+
RESULTS_DIR="${BASE_DIR}/results/${WANDB_RUN_NAME}"
84+
CKPT_DIR="${BASE_DIR}/checkpoints/${WANDB_RUN_NAME}"
85+
86+
mkdir -p "${RESULTS_DIR}" "${CKPT_DIR}"
87+
88+
MOUNTS="${DATA_DIR}:${CODE_MOUNT}/data,${RESULTS_DIR}:${CODE_MOUNT}/results,${CKPT_DIR}:${CODE_MOUNT}/checkpoints"
89+
90+
91+
read -r -d '' COMMAND <<'OUTER_EOF' || true
92+
set -euxo pipefail
93+
94+
echo "========================================="
95+
echo "CodonFM ${CONFIG_NAME} - STRATEGY: ${DIST_STRATEGY} - PRECISION: ${PRECISION_TAG} - CLUSTER: ${CLUSTER_NAME}"
96+
echo "Job ID: ${SLURM_JOB_ID}"
97+
echo "Nodes: ${SLURM_JOB_NUM_NODES}"
98+
echo "========================================="
99+
100+
# Pick training script based on distributed strategy.
101+
case "${DIST_STRATEGY}" in
102+
fsdp)
103+
TRAIN_SCRIPT=train_fsdp2.py
104+
;;
105+
ddp)
106+
TRAIN_SCRIPT=train_ddp.py
107+
;;
108+
*)
109+
echo "DIST_STRATEGY must be 'fsdp' or 'ddp', got '${DIST_STRATEGY}'" >&2
110+
exit 1
111+
;;
112+
esac
113+
114+
torchrun --nproc_per_node=${NPROC_PER_NODE} ${TRAIN_SCRIPT} \
115+
--config-name ${CONFIG_NAME} \
116+
quant_stats_config.enabled=${QUANT_STATS_ENABLED} \
117+
logger.frequency=${LOGGER_FREQUENCY} \
118+
num_train_steps=${NUM_TRAIN_STEPS} \
119+
dataset.micro_batch_size=${MICRO_BATCH_SIZE} \
120+
adamw_kwargs.lr=${LEARNING_RATE} \
121+
dataset.num_workers=${NUM_WORKERS} \
122+
dataset.data_path=/workspace/bionemo/data/processed_unfiltered/ \
123+
use_sequence_packing=${USE_SEQUENCE_PACKING} \
124+
precision=${PRECISION} \
125+
grad_reduce_type=${GRAD_REDUCE_TYPE} \
126+
lr_scheduler_kwargs.num_warmup_steps=${NUM_WARMUP_STEPS} \
127+
wandb_init_args.name=${WANDB_RUN_NAME} \
128+
+wandb_init_args.id=${WANDB_RUN_NAME} \
129+
+wandb_init_args.project=${WANDB_PROJECT} \
130+
checkpoint.save_final_model=${SAVE_FINAL_MODEL} \
131+
checkpoint.save_every_n_steps=${SAVE_EVERY_N_STEPS} \
132+
checkpoint.ckpt_dir=/workspace/bionemo/checkpoints \
133+
checkpoint.resume_from_checkpoint=${RESUME_FROM_CHECKPOINT} \
134+
hydra.run.dir=${HYDRA_RUN_DIR} \
135+
fp8_config.enabled=${FP8_ENABLED} \
136+
fp8_config.fp8_recipe=${FP8_RECIPE} \
137+
fp8_config.fp8_format=${FP8_FORMAT} \
138+
+dataset.pad_to_multiple_of=32
139+
140+
echo "========================================="
141+
echo "Training complete!"
142+
echo "========================================="
143+
OUTER_EOF
144+
145+
# Inject environment variables into the command.
146+
COMMAND="export DIST_STRATEGY=\"${DIST_STRATEGY}\"; ${COMMAND}"
147+
COMMAND="export PRECISION_TAG=\"${PRECISION_TAG}\"; ${COMMAND}"
148+
COMMAND="export CLUSTER_NAME=\"${CLUSTER_NAME}\"; ${COMMAND}"
149+
COMMAND="export NPROC_PER_NODE=\"${NPROC_PER_NODE}\"; ${COMMAND}"
150+
COMMAND="export CONFIG_NAME=\"${CONFIG_NAME}\"; ${COMMAND}"
151+
COMMAND="export QUANT_STATS_ENABLED=\"${QUANT_STATS_ENABLED}\"; ${COMMAND}"
152+
COMMAND="export LOGGER_FREQUENCY=\"${LOGGER_FREQUENCY}\"; ${COMMAND}"
153+
COMMAND="export NUM_TRAIN_STEPS=\"${NUM_TRAIN_STEPS}\"; ${COMMAND}"
154+
COMMAND="export MICRO_BATCH_SIZE=\"${MICRO_BATCH_SIZE}\"; ${COMMAND}"
155+
COMMAND="export LEARNING_RATE=\"${LEARNING_RATE}\"; ${COMMAND}"
156+
COMMAND="export NUM_WORKERS=\"${NUM_WORKERS}\"; ${COMMAND}"
157+
COMMAND="export USE_SEQUENCE_PACKING=\"${USE_SEQUENCE_PACKING}\"; ${COMMAND}"
158+
COMMAND="export PRECISION=\"${PRECISION}\"; ${COMMAND}"
159+
COMMAND="export GRAD_REDUCE_TYPE=\"${GRAD_REDUCE_TYPE}\"; ${COMMAND}"
160+
COMMAND="export NUM_WARMUP_STEPS=\"${NUM_WARMUP_STEPS}\"; ${COMMAND}"
161+
COMMAND="export WANDB_RUN_NAME=\"${WANDB_RUN_NAME}\"; ${COMMAND}"
162+
COMMAND="export WANDB_PROJECT=\"${WANDB_PROJECT}\"; ${COMMAND}"
163+
COMMAND="export SAVE_FINAL_MODEL=\"${SAVE_FINAL_MODEL}\"; ${COMMAND}"
164+
COMMAND="export SAVE_EVERY_N_STEPS=\"${SAVE_EVERY_N_STEPS}\"; ${COMMAND}"
165+
COMMAND="export RESUME_FROM_CHECKPOINT=\"${RESUME_FROM_CHECKPOINT}\"; ${COMMAND}"
166+
COMMAND="export HYDRA_RUN_DIR=\"${HYDRA_RUN_DIR}\"; ${COMMAND}"
167+
COMMAND="export FP8_ENABLED=\"${FP8_ENABLED}\"; ${COMMAND}"
168+
COMMAND="export FP8_RECIPE=\"${FP8_RECIPE}\"; ${COMMAND}"
169+
COMMAND="export FP8_FORMAT=\"${FP8_FORMAT}\"; ${COMMAND}"
170+
171+
COMMAND="export WANDB_API_KEY=\"${WANDB_API_KEY}\"; ${COMMAND}"
172+
COMMAND="export HUGGING_FACE_HUB_TOKEN=\"${HUGGING_FACE_HUB_TOKEN}\"; ${COMMAND}"
173+
COMMAND="export HF_TOKEN=\"${HUGGING_FACE_HUB_TOKEN}\"; ${COMMAND}"
174+
175+
echo "Launching: ${WANDB_RUN_NAME}"
176+
177+
# AUTO-CHAIN: resubmit on timeout.
178+
trap '
179+
rc=$?
180+
if [ "$rc" -eq 143 ] || [ "$rc" -eq 137 ]; then
181+
echo "Killed by signal (rc=$rc) — assuming SLURM timeout, resubmitting..."
182+
sbatch --dependency=singleton "${BASH_SOURCE[0]}"
183+
elif [ "$rc" -eq 0 ]; then
184+
echo "Clean exit — training finished, NOT resubmitting."
185+
else
186+
echo "Error exit (rc=$rc) — NOT resubmitting; investigate ${RESULTS_DIR}"
187+
fi
188+
' EXIT
189+
190+
srun \
191+
--output "${RESULTS_DIR}/slurm-%j-%n.out" \
192+
--error "${RESULTS_DIR}/error-%j-%n.out" \
193+
--container-image "${CONTAINER}" \
194+
--container-mounts "${MOUNTS}" \
195+
bash -c "${COMMAND}"

bionemo-recipes/recipes/codonfm_native_te/tests/test_train.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -130,15 +130,31 @@ def test_sanity_convergence_fsdp2_fp8(tmp_path, recipe_path):
130130
assert final_loss < 5.0, f"Final loss {final_loss} is too high"
131131

132132

133-
def test_sanity_convergence_fsdp2_fp32_master_weights(tmp_path, recipe_path):
134-
"""Test CodonFM with FP32 master weights."""
133+
def test_sanity_convergence_fsdp2_bf16_mixed(tmp_path, recipe_path):
134+
"""Test CodonFM with bf16-mixed precision (fp32 master weights + bf16 compute)."""
135135
with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"):
136136
sanity_config = compose(
137137
config_name="L0_sanity",
138138
overrides=[
139139
f"+wandb_init_args.dir={tmp_path}",
140140
f"checkpoint.ckpt_dir={tmp_path}",
141-
"use_fp32_master_weights=true",
141+
"precision=bf16-mixed",
142+
],
143+
)
144+
145+
final_loss = main_fsdp2(sanity_config)
146+
assert final_loss < 5.0, f"Final loss {final_loss} is too high"
147+
148+
149+
def test_sanity_convergence_fsdp2_bf16(tmp_path, recipe_path):
150+
"""Test CodonFM with pure bf16 precision."""
151+
with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"):
152+
sanity_config = compose(
153+
config_name="L0_sanity",
154+
overrides=[
155+
f"+wandb_init_args.dir={tmp_path}",
156+
f"checkpoint.ckpt_dir={tmp_path}",
157+
"precision=bf16",
142158
],
143159
)
144160

0 commit comments

Comments
 (0)