Skip to content

Commit febd9f5

Browse files
committed
make sure we initialize accelerator before model
Signed-off-by: Peter St. John <pstjohn@nvidia.com>
1 parent a1e2340 commit febd9f5

2 files changed

Lines changed: 19 additions & 2 deletions

File tree

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
defaults:
2+
- defaults
3+
- _self_
4+
5+
model_tag: nvidia/esm2_t48_15B_UR50D
6+
stop_after_n_steps: 500
7+
trainer:
8+
run_name: "esm2_t48_15B_UR50D_perf"
9+
per_device_train_batch_size: 12
10+
per_device_eval_batch_size: 12
11+
report_to: "wandb"
12+
learning_rate: 1.6e-4
13+
weight_decay: 0.1
14+
warmup_steps: 20_000

recipes/esm2_accelerate/train.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import hydra
2020
import torch
2121
import transformers
22+
from accelerate import Accelerator
2223
from omegaconf import DictConfig
2324
from transformers import AutoConfig, AutoModelForMaskedLM
2425
from transformers.trainer import Trainer
@@ -35,6 +36,10 @@
3536
@hydra.main(config_path="hydra_config", config_name="L0_sanity", version_base="1.2")
3637
def main(args: DictConfig):
3738
"""Entrypoint."""
39+
# Initialize the Accelerator manually before creating the model; otherwise the current torch device is never set and
40+
# model creation happens entirely on a single GPU, which typically leads to an OOM.
41+
_ = Accelerator()
42+
3843
config = AutoConfig.from_pretrained(args.model_tag, trust_remote_code=True)
3944
config.max_seq_length = args.max_seq_length
4045
config.micro_batch_size = args.trainer.per_device_train_batch_size
@@ -57,8 +62,6 @@ def main(args: DictConfig):
5762
callbacks=[StopAfterNStepsCallback(args.stop_after_n_steps)],
5863
)
5964

60-
logger.info("ACCELERATE STATE:\n%s\n", trainer.accelerator.state)
61-
6265
if training_args.do_train:
6366
Path(training_args.output_dir).mkdir(parents=True, exist_ok=True)
6467
last_checkpoint = transformers.trainer_utils.get_last_checkpoint(training_args.output_dir)

0 commit comments

Comments
 (0)