1+ from __future__ import annotations
2+
13import os
24import re
35import subprocess
46import warnings
5- from typing import Any , Callable , cast , Dict , List , Mapping , Optional , Tuple , Union
7+ from collections .abc import Callable , Mapping
8+ from typing import Any , cast
69
710import torch
811import torch .distributed as dist
@@ -40,19 +43,19 @@ class _NativeDistModel(ComputationModel):
# Backends this torch install can actually use: probe torch.distributed's
# per-backend availability flags and keep only the ones that report True.
available_backends = tuple(
    backend_name for backend_name in [NCCL, GLOO, MPI] if getattr(dist, f"is_{backend_name}_available")()
)
4144
4245 @staticmethod
43- def create_from_context () -> Optional [ " _NativeDistModel" ] :
46+ def create_from_context () -> _NativeDistModel | None :
4447 if not (dist .is_available () and dist .is_initialized ()):
4548 return None
4649 return _NativeDistModel ()
4750
4851 @staticmethod
4952 def create_from_backend (
5053 backend : str ,
51- init_method : Optional [ str ] = None ,
52- world_size : Optional [ int ] = None ,
53- rank : Optional [ int ] = None ,
54+ init_method : str | None = None ,
55+ world_size : int | None = None ,
56+ rank : int | None = None ,
5457 ** kwargs : Any ,
55- ) -> " _NativeDistModel" :
58+ ) -> _NativeDistModel :
5659 if backend not in _NativeDistModel .available_backends :
5760 raise ValueError (f"Backend should be one of '{ _NativeDistModel .available_backends } '" )
5861
@@ -74,20 +77,20 @@ def create_from_backend(
7477
7578 def __init__ (
7679 self ,
77- backend : Optional [ str ] = None ,
78- timeout : Optional [ int ] = None ,
79- init_method : Optional [ str ] = None ,
80- world_size : Optional [ int ] = None ,
81- rank : Optional [ int ] = None ,
80+ backend : str | None = None ,
81+ timeout : int | None = None ,
82+ init_method : str | None = None ,
83+ world_size : int | None = None ,
84+ rank : int | None = None ,
8285 ** kwargs : Any ,
8386 ) -> None :
8487 """This is a private method. Please, use `create_from_backend` or `create_from_context`"""
8588 super ().__init__ ()
86- self ._env_backup : Optional [ Dict [ str , str ]] = None
87- self ._local_rank : Optional [ int ] = None
88- self ._master_port : Optional [ int ] = None
89- self ._master_addr : Optional [ str ] = None
90- self ._init_method : Optional [ str ] = None
89+ self ._env_backup : dict [ str , str ] | None = None
90+ self ._local_rank : int | None = None
91+ self ._master_port : int | None = None
92+ self ._master_addr : str | None = None
93+ self ._init_method : str | None = None
9194 if backend is not None :
9295 self ._create_from_backend (
9396 backend , timeout = timeout , init_method = init_method , world_size = world_size , rank = rank , ** kwargs
@@ -98,18 +101,18 @@ def __init__(
98101 def _create_from_backend (
99102 self ,
100103 backend : str ,
101- timeout : Optional [ int ] = None ,
102- init_method : Optional [ str ] = None ,
103- world_size : Optional [ int ] = None ,
104- rank : Optional [ int ] = None ,
104+ timeout : int | None = None ,
105+ init_method : str | None = None ,
106+ world_size : int | None = None ,
107+ rank : int | None = None ,
105108 ** kwargs : Any ,
106109 ) -> None :
107110 if backend == dist .Backend .NCCL and not torch .cuda .is_available ():
108111 raise RuntimeError ("Nccl backend is required but no cuda capable devices" )
109112 self ._backend = backend
110113 self .setup_env_vars (rank , world_size )
111114
112- init_pg_kwargs : Dict [str , Any ] = {}
115+ init_pg_kwargs : dict [str , Any ] = {}
113116 if timeout is not None :
114117 init_pg_kwargs ["timeout" ] = timeout
115118
@@ -156,7 +159,7 @@ def _compute_nproc_per_node(self) -> int:
156159 dist .destroy_process_group (gloo_group )
157160 return int (tensor .item ())
158161
159- def _get_all_hostnames (self ) -> List [ Tuple [str , ...]]:
162+ def _get_all_hostnames (self ) -> list [ tuple [str , ...]]:
160163 import socket
161164
162165 device = "cpu"
@@ -172,7 +175,7 @@ def _get_all_hostnames(self) -> List[Tuple[str, ...]]:
172175 return [tuple (t .cpu ().tolist ()) for t in out_t_names ]
173176
174177 @staticmethod
175- def _compute_node_and_local_ranks (rank : int , hostnames : List [ Tuple [str , ...]]) -> Tuple [int , int ]:
178+ def _compute_node_and_local_ranks (rank : int , hostnames : list [ tuple [str , ...]]) -> tuple [int , int ]:
176179 from collections import Counter
177180
178181 c : Counter = Counter (hostnames )
@@ -213,7 +216,7 @@ def _identify_local_rank(self) -> None:
213216 # use socket gethostname heuristic to determine number of nodes => local rank
214217 self ._local_rank = self ._compute_local_rank_via_hostname ()
215218
216- def setup_env_vars (self , rank : Optional [ int ] = None , world_size : Optional [ int ] = None ) -> None :
219+ def setup_env_vars (self , rank : int | None = None , world_size : int | None = None ) -> None :
217220 self ._env_backup = os .environ .copy ()
218221
219222 if "SLURM_JOB_ID" in os .environ :
@@ -253,7 +256,7 @@ def _setup_env_in_slurm(self) -> None:
253256 if k not in os .environ :
254257 raise RuntimeError (f"SLURM distributed configuration is missing '{ k } ' in env variables" )
255258
256- ddp_vars = _setup_ddp_vars_from_slurm_env (cast (Dict , os .environ ))
259+ ddp_vars = _setup_ddp_vars_from_slurm_env (cast (dict , os .environ ))
257260
258261 # define DDP env vars required by PTH:
259262 for key , value in ddp_vars .items ():
@@ -307,13 +310,13 @@ def _dist_worker_task_fn(
307310 local_rank : int ,
308311 backend : str ,
309312 fn : Callable ,
310- args : Tuple ,
313+ args : tuple ,
311314 kw_dict : Mapping ,
312315 world_size : int ,
313316 nprocs_per_node : int ,
314317 node_rank : int ,
315- master_addr : Optional [ str ] ,
316- master_port : Optional [ str ] ,
318+ master_addr : str | None ,
319+ master_port : str | None ,
317320 init_method : str ,
318321 kw : Any ,
319322 ) -> None :
@@ -326,8 +329,8 @@ def _dist_worker_task_fn(
326329 os .environ ["RANK" ] = str (rank )
327330 os .environ ["WORLD_SIZE" ] = str (world_size )
328331
329- arg_world_size : Optional [ int ] = world_size
330- arg_rank : Optional [ int ] = rank
332+ arg_world_size : int | None = world_size
333+ arg_rank : int | None = rank
331334 if init_method == "env://" :
332335 os .environ ["MASTER_ADDR" ] = str (master_addr )
333336 os .environ ["MASTER_PORT" ] = str (master_port )
@@ -348,15 +351,15 @@ def _dist_worker_task_fn(
348351 # pyrefly: ignore [bad-override]
349352 def spawn (
350353 fn : Callable ,
351- args : Tuple ,
352- kwargs_dict : Optional [ Mapping ] = None ,
354+ args : tuple ,
355+ kwargs_dict : Mapping | None = None ,
353356 nproc_per_node : int = 1 ,
354357 nnodes : int = 1 ,
355358 node_rank : int = 0 ,
356- master_addr : Optional [ str ] = None ,
357- master_port : Optional [ int ] = None ,
359+ master_addr : str | None = None ,
360+ master_port : int | None = None ,
358361 backend : str = "nccl" ,
359- init_method : Optional [ str ] = None ,
362+ init_method : str | None = None ,
360363 ** kwargs : Any ,
361364 ) -> None :
362365 world_size = nnodes * nproc_per_node
@@ -427,7 +430,7 @@ def _setup_group(self, group: Any) -> dist.ProcessGroup:
427430 "OR" : dist .ReduceOp .BOR ,
428431 }
429432
430- def _do_all_reduce (self , tensor : torch .Tensor , op : str = "SUM" , group : Optional [ Any ] = None ) -> torch .Tensor :
433+ def _do_all_reduce (self , tensor : torch .Tensor , op : str = "SUM" , group : Any | None = None ) -> torch .Tensor :
431434 if op not in self ._reduce_op_map :
432435 raise ValueError (f"Unsupported reduction operation: '{ op } '" )
433436 if group is not None :
@@ -440,7 +443,7 @@ def _do_all_reduce(self, tensor: torch.Tensor, op: str = "SUM", group: Optional[
440443 dist .all_reduce (tensor , reduce_op )
441444 return tensor
442445
443- def _do_all_gather (self , tensor : torch .Tensor , group : Optional [ Any ] = None ) -> torch .Tensor :
446+ def _do_all_gather (self , tensor : torch .Tensor , group : Any | None = None ) -> torch .Tensor :
444447 if group is not None :
445448 group = self ._setup_group (group )
446449 if self ._rank_not_in_group (group ):
@@ -459,7 +462,7 @@ def _do_all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None) -> t
459462 dist .all_gather (output , tensor )
460463 return torch .cat (output , dim = 0 )
461464
462- def _do_all_gather_object (self , tensor : Any , group : Optional [ Any ] = None ) -> List [Any ]:
465+ def _do_all_gather_object (self , tensor : Any , group : Any | None = None ) -> list [Any ]:
463466 if Version (torch .__version__ ) < Version ("1.7.0" ):
464467 raise RuntimeError (
465468 "Current torch version does not implement dist.all_gather_object. "
@@ -482,7 +485,7 @@ def _do_all_gather_object(self, tensor: Any, group: Optional[Any] = None) -> Lis
482485
483486 return output
484487
def _do_new_group(self, ranks: list[int], **kwargs: Any) -> Any:
    """Create a new process group over ``ranks``.

    Thin pass-through to :func:`torch.distributed.new_group`; extra keyword
    arguments are forwarded unchanged.
    """
    created_group = dist.new_group(ranks=ranks, **kwargs)
    return created_group
487490
488491 def _do_broadcast (self , tensor : torch .Tensor , src : int ) -> torch .Tensor :
@@ -492,10 +495,10 @@ def _do_broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
def barrier(self) -> None:
    """Synchronize: block until every process in the default group reaches this call."""
    dist.barrier()
494497
495- def _rank_not_in_group (self , group : Optional [ Any ] ) -> bool :
498+ def _rank_not_in_group (self , group : Any | None ) -> bool :
496499 return dist ._rank_not_in_group (group )
497500
498- def _expand_hostlist (nodelist : str ) -> List [str ]:
501+ def _expand_hostlist (nodelist : str ) -> list [str ]:
499502 """Expand a compressed hostlist string and returns all hosts listed.
500503
501504 Source : https://github.com/LLNL/py-hostlist/blob/master/hostlist/hostlist.py
@@ -561,7 +564,7 @@ def _expand_hostlist(nodelist: str) -> List[str]:
561564
562565 return result_hostlist
563566
564- def _setup_ddp_vars_from_slurm_env (environ : Dict [str , str ]) -> Dict [str , Union [ str , int ] ]:
567+ def _setup_ddp_vars_from_slurm_env (environ : dict [str , str ]) -> dict [str , str | int ]:
565568 """Method to setup DDP env vars required by PyTorch from SLURM env"""
566569 # 1) Tools like enroot can have hooks to translate slurm env vars to RANK, LOCAL_RANK, WORLD_SIZE etc
567570 # See https://github.com/NVIDIA/enroot/blob/v3.1.0/conf/hooks/extra/50-slurm-pytorch.sh
@@ -572,7 +575,7 @@ def _setup_ddp_vars_from_slurm_env(environ: Dict[str, str]) -> Dict[str, Union[s
572575 # To cover case 2), let's check that defined RANK >= SLURM_PROCID, LOCAL_RANK >= SLURM_LOCALID,
573576 # WORLD_SIZE >= SLURM_NTASKS, SLURM_JOB_NUM_NODES == 1
574577
575- ddp_vars : Dict [str , Union [ str , int , None ] ] = {
578+ ddp_vars : dict [str , str | int | None ] = {
576579 "RANK" : int (environ ["SLURM_PROCID" ]),
577580 "LOCAL_RANK" : int (environ ["SLURM_LOCALID" ]),
578581 "WORLD_SIZE" : int (environ ["SLURM_NTASKS" ]),
@@ -646,4 +649,4 @@ def _setup_ddp_vars_from_slurm_env(environ: Dict[str, str]) -> Dict[str, Union[s
646649 slurm_port = slurm_port [- 4 :]
647650 ddp_vars ["MASTER_PORT" ] = int (slurm_port ) + 15000
648651
649- return cast (Dict [str , Union [ str , int ] ], ddp_vars )
652+ return cast (dict [str , str | int ], ddp_vars )
0 commit comments