Skip to content

Commit fca6ead

Browse files
authored
make sure we initialize accelerator before model (#1132)
We need to initialize the `Accelerator` object before creating TE layers or they all end up on a single device <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - New Features - Added a ready-to-run performance test preset for the esm2 t48 15B model with sensible defaults: model tag, step cap, batch sizes, learning rate, weight decay, warmup steps, and Weights & Biases logging. - Bug Fixes - Improved multi-GPU initialization by starting distributed state earlier, reducing setup issues and OOM risk without changing training behavior. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Peter St. John <pstjohn@nvidia.com>
1 parent 0d30652 commit fca6ead

2 files changed

Lines changed: 24 additions & 2 deletions

File tree

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
defaults:
2+
- defaults
3+
- _self_
4+
5+
model_tag: nvidia/esm2_t48_15B_UR50D
6+
stop_after_n_steps: 500
7+
trainer:
8+
run_name: "esm2_t48_15B_UR50D_perf"
9+
per_device_train_batch_size: 12
10+
per_device_eval_batch_size: 12
11+
report_to: "wandb"
12+
learning_rate: 1.6e-4
13+
weight_decay: 0.1
14+
warmup_steps: 20_000

recipes/esm2_accelerate/train.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import hydra
2020
import torch
2121
import transformers
22+
from accelerate import PartialState
2223
from omegaconf import DictConfig
2324
from transformers import AutoConfig, AutoModelForMaskedLM
2425
from transformers.trainer import Trainer
@@ -35,6 +36,15 @@
3536
@hydra.main(config_path="hydra_config", config_name="L0_sanity", version_base="1.2")
3637
def main(args: DictConfig):
3738
"""Entrypoint."""
39+
# Initialize Accelerate's distributed state early so torch device is set per process
40+
state = PartialState()
41+
logger.info(
42+
"Accelerate initialized (local_process_index=%s, num_processes=%s, device=%s)",
43+
state.local_process_index,
44+
state.num_processes,
45+
state.device,
46+
)
47+
3848
config = AutoConfig.from_pretrained(args.model_tag, trust_remote_code=True)
3949
config.max_seq_length = args.max_seq_length
4050
config.micro_batch_size = args.trainer.per_device_train_batch_size
@@ -57,8 +67,6 @@ def main(args: DictConfig):
5767
callbacks=[StopAfterNStepsCallback(args.stop_after_n_steps)],
5868
)
5969

60-
logger.info("ACCELERATE STATE:\n%s\n", trainer.accelerator.state)
61-
6270
if training_args.do_train:
6371
Path(training_args.output_dir).mkdir(parents=True, exist_ok=True)
6472
last_checkpoint = transformers.trainer_utils.get_last_checkpoint(training_args.output_dir)

0 commit comments

Comments
 (0)