Skip to content

Commit 710e405

Browse files
committed
use PartialState instead
Signed-off-by: Peter St. John <pstjohn@nvidia.com>
1 parent febd9f5 commit 710e405

1 file changed

Lines changed: 9 additions & 4 deletions

File tree

recipes/esm2_accelerate/train.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import hydra
2020
import torch
2121
import transformers
22-
from accelerate import Accelerator
22+
from accelerate import PartialState
2323
from omegaconf import DictConfig
2424
from transformers import AutoConfig, AutoModelForMaskedLM
2525
from transformers.trainer import Trainer
@@ -36,9 +36,14 @@
3636
@hydra.main(config_path="hydra_config", config_name="L0_sanity", version_base="1.2")
3737
def main(args: DictConfig):
3838
"""Entrypoint."""
39-
# We need to initialize the Accelerator manually prior to creating our model, otherwise we won't end up setting the
40-
# current torch device and the model creation will all happen on a single GPU, typically leading to an OOM.
41-
_ = Accelerator()
39+
# Initialize Accelerate's distributed state early so torch device is set per process
40+
state = PartialState()
41+
logger.info(
42+
"Accelerate initialized (local_process_index=%s, num_processes=%s, device=%s)",
43+
state.local_process_index,
44+
state.num_processes,
45+
state.device,
46+
)
4247

4348
config = AutoConfig.from_pretrained(args.model_tag, trust_remote_code=True)
4449
config.max_seq_length = args.max_seq_length

0 commit comments

Comments
 (0)