Skip to content

Commit 1db0a0c

Browse files
committed
Add train ddp to recipe
1 parent e870d31 commit 1db0a0c

4 files changed

Lines changed: 453 additions & 1 deletion

File tree

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,14 @@
11
# syntax=docker/dockerfile:1.4
2-
FROM nvcr.io/nvidia/pytorch:26.04-py3
2+
FROM nvcr.io/nvidia/pytorch:26.02-py3
3+
4+
RUN apt-get update && apt-get install -y tmux npm
35

46
RUN --mount=type=cache,target=/root/.cache/pip \
57
--mount=type=bind,source=requirements.txt,target=/requirements.txt \
68
PIP_CONSTRAINT= pip install -r /requirements.txt
79

10+
RUN curl -fsSL https://claude.ai/install.sh | bash # Install Claude CLI tool
11+
RUN npm install -g @openai/codex
12+
813
WORKDIR /workspace/bionemo
914
COPY . .

bionemo-recipes/recipes/codonfm_native_te/checkpoint.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,3 +221,91 @@ def save_final_model_fsdp2(
221221
save_file(model_state_dict, os.path.join(save_directory, "model.safetensors"))
222222
config.to_json_file(os.path.join(save_directory, "config.json"))
223223
logger.info(f"Saved final FSDP2 model to {save_directory}")
224+
225+
226+
# ============================================================================
227+
# DDP Checkpointing
228+
# ============================================================================
229+
230+
231+
def load_checkpoint_ddp(
232+
model: torch.nn.Module,
233+
optimizer: torch.optim.Optimizer,
234+
scheduler: torch.optim.lr_scheduler.LRScheduler,
235+
ckpt_path: str | os.PathLike,
236+
dist_config: DistributedConfig,
237+
) -> CheckpointOutput:
238+
"""Load DDP checkpoint."""
239+
checkpoint_path, _ = get_latest_checkpoint(ckpt_path)
240+
if not checkpoint_path:
241+
logger.info("No DDP checkpoint found, starting from scratch")
242+
return CheckpointOutput(model, optimizer, scheduler, 0, 0)
243+
244+
checkpoint = torch.load(
245+
checkpoint_path / "checkpoint.pt",
246+
map_location=f"cuda:{dist_config.local_rank}",
247+
weights_only=True,
248+
)
249+
250+
model.load_state_dict(checkpoint["model"], strict=False)
251+
optimizer.load_state_dict(checkpoint["optimizer"])
252+
scheduler.load_state_dict(checkpoint["scheduler"])
253+
254+
if dist_config.is_main_process():
255+
logger.info(f"Loaded DDP checkpoint from step {checkpoint['step']}")
256+
257+
# Increment the step by one to avoid re-running the previous step.
258+
return CheckpointOutput(model, optimizer, scheduler, checkpoint["step"] + 1, checkpoint["epoch"])
259+
260+
261+
def save_checkpoint_ddp(
262+
model: torch.nn.Module,
263+
optimizer: torch.optim.Optimizer,
264+
scheduler: torch.optim.lr_scheduler.LRScheduler,
265+
ckpt_path: str | os.PathLike,
266+
step: int,
267+
epoch: int,
268+
dist_config: DistributedConfig,
269+
max_checkpoints: int | None = None,
270+
) -> None:
271+
"""Save DDP checkpoint (rank-0 only since the model is replicated)."""
272+
if not dist_config.is_main_process():
273+
return
274+
275+
ckpt_path = Path(ckpt_path)
276+
checkpoint_path = ckpt_path / f"step_{step}"
277+
checkpoint_path.mkdir(parents=True, exist_ok=True)
278+
279+
torch.save(
280+
{
281+
"model": model.state_dict(),
282+
"optimizer": optimizer.state_dict(),
283+
"scheduler": scheduler.state_dict(),
284+
"step": step,
285+
"epoch": epoch,
286+
},
287+
checkpoint_path / "checkpoint.pt",
288+
)
289+
logger.info(f"Saved DDP checkpoint to {checkpoint_path}")
290+
291+
if max_checkpoints is not None:
292+
prune_checkpoints(ckpt_path, max_checkpoints)
293+
294+
295+
def save_final_model_ddp(
296+
model: torch.nn.Module,
297+
config,
298+
save_directory: str | os.PathLike,
299+
dist_config: DistributedConfig,
300+
) -> None:
301+
"""Save final model for DDP - only on main process."""
302+
if not dist_config.is_main_process():
303+
return
304+
305+
# Unwrap DDP if wrapped.
306+
underlying_model = model.module if hasattr(model, "module") else model
307+
308+
os.makedirs(save_directory, exist_ok=True)
309+
save_file(underlying_model.state_dict(), os.path.join(save_directory, "model.safetensors"))
310+
config.to_json_file(os.path.join(save_directory, "config.json"))
311+
logger.info(f"Saved final DDP model to {save_directory}")
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
#!/usr/bin/env bash
2+
set -euo pipefail
3+
4+
export CPATH=/usr/local/cuda/include
5+
export TRITON_PTXAS_PATH=/usr/local/cuda/bin/ptxas
6+
7+
# Run config
8+
export CONFIG_NAME=encodon_1b
9+
export NPROC_PER_NODE=8
10+
export DIST_STRATEGY=ddp # fsdp or ddp
11+
12+
# Training
13+
export NUM_TRAIN_STEPS=100
14+
export MICRO_BATCH_SIZE=31
15+
export NUM_WORKERS=1
16+
export USE_SEQUENCE_PACKING=True
17+
export USE_FP32_MASTER_WEIGHTS=True
18+
export NUM_WARMUP_STEPS=500
19+
20+
# Logging / W&B
21+
export LOGGER_FREQUENCY=10
22+
export WANDB_API_KEY=""
23+
export WANDB_PROJECT=codon-fm-low-precision
24+
25+
# Checkpointing
26+
export SAVE_FINAL_MODEL=False
27+
export SAVE_EVERY_N_STEPS=100000
28+
export CKPT_DIR=/tmp
29+
export RESUME_FROM_CHECKPOINT=False
30+
31+
# Hydra
32+
export HYDRA_RUN_DIR=1b_test
33+
34+
# Quantization / FP8
35+
export QUANT_STATS_ENABLED=False
36+
export FP8_ENABLED=True
37+
export FP8_RECIPE=transformer_engine.common.recipe.MXFP8BlockScaling
38+
export FP8_FORMAT=E4M3
39+
40+
# Data
41+
export DATASET_DATA_PATH=/data/balvisio/codonfm/reference-dataset/codonfm/processed_unfiltered/
42+
43+
# Derived: build wandb run name from model size, batch size, and precision recipe
44+
MODEL_SIZE="${CONFIG_NAME##*_}"
45+
if [ "${FP8_ENABLED}" = "True" ]; then
46+
RECIPE_SHORT="${FP8_RECIPE##*.}"
47+
RECIPE_SHORT="${RECIPE_SHORT%BlockScaling}"
48+
RECIPE_SHORT="${RECIPE_SHORT%Scaling}"
49+
PRECISION_TAG="${RECIPE_SHORT,,}_${FP8_FORMAT,,}"
50+
else
51+
PRECISION_TAG="bf16"
52+
fi
53+
export WANDB_RUN_NAME="${MODEL_SIZE}_${DIST_STRATEGY}_bs${MICRO_BATCH_SIZE}_${PRECISION_TAG}"
54+
55+
# Pick training script based on distributed strategy.
56+
# DDP can't emulate FSDP's fp32-master / bf16-param split, so force fp32 master weights off.
57+
case "${DIST_STRATEGY}" in
58+
fsdp)
59+
TRAIN_SCRIPT=train_fsdp2.py
60+
;;
61+
ddp)
62+
TRAIN_SCRIPT=train_ddp.py
63+
if [ "${USE_FP32_MASTER_WEIGHTS}" = "True" ]; then
64+
echo "DIST_STRATEGY=ddp: overriding USE_FP32_MASTER_WEIGHTS=True -> False" >&2
65+
export USE_FP32_MASTER_WEIGHTS=False
66+
fi
67+
;;
68+
*)
69+
echo "DIST_STRATEGY must be 'fsdp' or 'ddp', got '${DIST_STRATEGY}'" >&2
70+
exit 1
71+
;;
72+
esac
73+
74+
torchrun --nproc_per_node=${NPROC_PER_NODE} ${TRAIN_SCRIPT} \
75+
--config-name ${CONFIG_NAME} \
76+
quant_stats_config.enabled=${QUANT_STATS_ENABLED} \
77+
logger.frequency=${LOGGER_FREQUENCY} \
78+
num_train_steps=${NUM_TRAIN_STEPS} \
79+
dataset.micro_batch_size=${MICRO_BATCH_SIZE} \
80+
dataset.num_workers=${NUM_WORKERS} \
81+
dataset.data_path=${DATASET_DATA_PATH} \
82+
use_sequence_packing=${USE_SEQUENCE_PACKING} \
83+
use_fp32_master_weights=${USE_FP32_MASTER_WEIGHTS} \
84+
lr_scheduler_kwargs.num_warmup_steps=${NUM_WARMUP_STEPS} \
85+
wandb_init_args.name=${WANDB_RUN_NAME} \
86+
+wandb_init_args.project=${WANDB_PROJECT} \
87+
checkpoint.save_final_model=${SAVE_FINAL_MODEL} \
88+
checkpoint.save_every_n_steps=${SAVE_EVERY_N_STEPS} \
89+
checkpoint.ckpt_dir=${CKPT_DIR} \
90+
checkpoint.resume_from_checkpoint=${RESUME_FROM_CHECKPOINT} \
91+
hydra.run.dir=${HYDRA_RUN_DIR} \
92+
fp8_config.enabled=${FP8_ENABLED} \
93+
fp8_config.fp8_recipe=${FP8_RECIPE} \
94+
fp8_config.fp8_format=${FP8_FORMAT}

0 commit comments

Comments
 (0)