Skip to content

Commit 3c2e6ce

Browse files
authored
updating esm2 native recipe (#1078)
* Adds separate `train_ddp.py`, `train_fsdp2.py`, and `train_nvfsdp.py` entrypoints * Adds comparison against FA-2 based HF transformers model <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Distributed ESM‑2 training entrypoints (DDP, MFSDP/FSDP2) and a shared linear warmup/decay LR scheduler. * **Configuration** * Switched defaults to nvFSDP-style sharding; updated model identifiers and training hyperparameters (train steps, warmup, optimizer LR). * **Documentation** * Added new ESM‑2 training README; removed an outdated README. * **Build/Chores** * Install Transformers from the Git repo; Docker builds can use netrc credentials for installs. * **Tests** * Added extensive single‑ and multi‑GPU training tests; removed obsolete tests. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Peter St. John <pstjohn@nvidia.com>
1 parent 7c18697 commit 3c2e6ce

25 files changed

Lines changed: 596 additions & 259 deletions

File tree

.devcontainer/recipes/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,6 @@ megatron-fsdp==0.1.0rc0
77
torchmetrics
88
tqdm
99
transformer_engine
10-
transformers
10+
transformers @ git+https://github.com/huggingface/transformers.git
1111
typer
1212
wandb

recipes/README.md

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -135,9 +135,7 @@ import torch
135135
from torch.distributed import init_process_group, destroy_process_group
136136

137137

138-
@hydra.main(
139-
config_path="hydra_config", config_name="L0_sanity.yaml", version_base="1.2"
140-
)
138+
@hydra.main(config_path="hydra_config", config_name="L0_sanity", version_base="1.2")
141139
def main(args: DictConfig):
142140
"""Main training entrypoint."""
143141

@@ -306,7 +304,7 @@ def test_accelerate_launch(accelerate_config, tmp_path):
306304
str(accelerate_config_path),
307305
"train.py",
308306
"--config-name",
309-
"L0_sanity.yaml",
307+
"L0_sanity",
310308
f"trainer.output_dir={tmp_path}",
311309
],
312310
cwd=recipe_dir,

recipes/amplify_accelerate_te_fp8/test_train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ def test_accelerate_launch(accelerate_config, tmp_path):
187187
str(accelerate_config_path),
188188
str(train_py),
189189
"--config-name",
190-
"L0_sanity.yaml",
190+
"L0_sanity",
191191
f"trainer.output_dir={tmp_path}",
192192
],
193193
cwd=recipe_dir,

recipes/amplify_accelerate_te_fp8/train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
logger = logging.getLogger(__name__)
3333

3434

35-
@hydra.main(config_path="hydra_config", config_name="L0_sanity.yaml", version_base="1.2")
35+
@hydra.main(config_path="hydra_config", config_name="L0_sanity", version_base="1.2")
3636
def main(args: DictConfig):
3737
"""Entrypoint."""
3838
config = AutoConfig.from_pretrained(args.model_tag, trust_remote_code=True)
File renamed without changes.
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
# syntax=docker/dockerfile:1.4
22
FROM nvcr.io/nvidia/pytorch:25.06-py3
33

4-
RUN --mount=type=cache,target=/root/.cache/pip \
4+
RUN --mount=type=secret,id=netrc,target=/root/.netrc \
5+
--mount=type=cache,target=/root/.cache/pip \
56
--mount=type=bind,source=requirements.txt,target=/requirements.txt \
67
PIP_CONSTRAINT= pip install -r /requirements.txt
78

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# ESM-2 training with megatron-fsdp and custom pytorch training loop
2+
3+
Build the docker image with the following command:
4+
5+
```bash
6+
docker build -t my_image .
7+
```
8+
9+
## Running training
10+
11+
Run training with
12+
13+
```bash
14+
docker run --rm -it --gpus all my_image torchrun train_mfsdp.py --config-name L0_sanity
15+
```

recipes/esm2_native_te_nvfsdp/hydra_config/L0_sanity.yaml renamed to recipes/esm2_native_te_mfsdp/hydra_config/L0_sanity.yaml

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@ defaults:
22
- defaults
33

44
# Training config
5-
model_name: esm2_t6_8M_UR50D
5+
model_name: nvidia/esm2_t6_8M_UR50D
66
micro_batch_size: 2
7-
num_train_steps: 5
7+
num_train_steps: 250
88

99
# WandB config
1010
wandb_init_args:
@@ -13,5 +13,7 @@ wandb_init_args:
1313

1414
# Learning rate scheduler config
1515
lr_scheduler_kwargs:
16-
num_warmup_steps: 2
17-
num_training_steps: 8
16+
num_warmup_steps: 0
17+
18+
adamw_kwargs:
19+
lr: 1e-2

0 commit comments

Comments
 (0)