77from torch .distributed .fsdp import fully_shard , MixedPrecisionPolicy
88from utils import inspect_mixed_precision , inspect_model
99
10+ def verify_min_gpu_count (min_gpus : int = 2 ) -> bool :
11+ """ verification that we have at least 2 gpus to run dist examples """
12+ has_gpu = torch .accelerator .is_available ()
13+ gpu_count = torch .accelerator .device_count ()
14+ return has_gpu and gpu_count >= min_gpus
1015
1116def set_modules_to_forward_prefetch (model , num_to_forward_prefetch ):
1217 for i , layer in enumerate (model .layers ):
@@ -30,9 +35,18 @@ def set_modules_to_backward_prefetch(model, num_to_backward_prefetch):
3035
3136def main (args ):
3237 rank = int (os .environ ["LOCAL_RANK" ])
33- device = torch .device (f"cuda:{ rank } " )
34- torch .cuda .set_device (device )
35- torch .distributed .init_process_group (backend = "nccl" , device_id = device )
38+ if torch .accelerator .is_available ():
39+ device_type = torch .accelerator .current_accelerator ()
40+ device : torch .device = torch .device (f"{ device_type } :{ rank } " )
41+ torch .accelerator .device_index (rank )
42+ print (f"Running on rank { rank } on device { device } " )
43+ backend = torch .distributed .get_default_backend_for_device (device )
44+ torch .distributed .init_process_group (backend = backend , device_id = device )
45+ else :
46+ device = torch .device ("cpu" )
47+ print (f"Running on device { device } " )
48+ torch .distributed .init_process_group (backend = "gloo" , device_id = device )
49+
3650 torch .manual_seed (0 )
3751 vocab_size = 1024
3852 batch_size = 32
@@ -64,7 +78,7 @@ def main(args):
6478
6579 checkpointer = Checkpointer ("checkpoints" , dcp_api = args .dcp_api )
6680 if checkpointer .last_training_time is None :
67- model .to_empty (device = "cuda" )
81+ model .to_empty (device = device )
6882 model .reset_parameters ()
6983 else :
7084 checkpointer .load_model (model )
@@ -96,4 +110,8 @@ def main(args):
96110 parser .add_argument ("--mixed-precision" , action = "store_true" , default = False )
97111 parser .add_argument ("--dcp-api" , action = "store_true" , default = False )
98112 args = parser .parse_args ()
113+ _min_gpu_count = 2
114+ if not verify_min_gpu_count (min_gpus = _min_gpu_count ):
115+ print (f"Unable to locate sufficient { _min_gpu_count } gpus to run this example. Exiting." )
116+ exit ()
99117 main (args )
0 commit comments