File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change 1919import hydra
2020import torch
2121import transformers
22- from accelerate import Accelerator
22+ from accelerate import PartialState
2323from omegaconf import DictConfig
2424from transformers import AutoConfig , AutoModelForMaskedLM
2525from transformers .trainer import Trainer
3636@hydra .main (config_path = "hydra_config" , config_name = "L0_sanity" , version_base = "1.2" )
3737def main (args : DictConfig ):
3838 """Entrypoint."""
39- # We need to initialize the Accelerator manually prior to creating our model, otherwise we won't end up setting the
40- # current torch device and the model creation will all happen on a single GPU, typically leading to an OOM.
41- _ = Accelerator ()
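39+ # Initialize Accelerate's distributed state early so that the torch device is set per process before the model is created; otherwise model creation would all happen on a single GPU, typically leading to an OOM.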
40+ state = PartialState ()
41+ logger .info (
42+ "Accelerate initialized (local_process_index=%s, num_processes=%s, device=%s)" ,
43+ state .local_process_index ,
44+ state .num_processes ,
45+ state .device ,
46+ )
4247
4348 config = AutoConfig .from_pretrained (args .model_tag , trust_remote_code = True )
4449 config .max_seq_length = args .max_seq_length
You can’t perform that action at this time.
0 commit comments