File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -41,15 +41,15 @@ def main(args):
4141 rank = int (os .environ ["LOCAL_RANK" ])
4242 if torch .accelerator .is_available ():
4343 device_type = torch .accelerator .current_accelerator ()
44- device : torch . device = torch .device (f"{ device_type } :{ rank } " )
44+ device = torch .device (f"{ device_type } :{ rank } " )
4545 torch .accelerator .device_index (rank )
4646 print (f"Running on rank { rank } on device { device } " )
47- backend = torch .distributed .get_default_backend_for_device (device )
48- torch .distributed .init_process_group (backend = backend , device_id = device )
4947 else :
5048 device = torch .device ("cpu" )
5149 print (f"Running on device { device } " )
52- torch .distributed .init_process_group (backend = "gloo" , device_id = device )
50+
51+ backend = torch .distributed .get_default_backend_for_device (device )
52+ torch .distributed .init_process_group (backend = backend , device_id = device )
5353
5454 torch .manual_seed (0 )
5555 vocab_size = 1024
You can’t perform that action at this time.
0 commit comments