33import pickle
44from abc import ABC , abstractmethod
55from datetime import timedelta
6- from typing import Any
6+ from typing import Any , Protocol
77
88import torch
99import torch .distributed as torch_dist
1010
1111
class CommunicatorProtocol(Protocol):
    """Structural type for communicator objects that expose ``all_gather``."""

    def all_gather(self, *args: Any, **kwargs: Any) -> torch.Tensor:
        """Gather across the communicator; arguments are left to the implementation."""
        ...
14+
15+
class CommGroup:
    """Process-group descriptor pairing a communicator handle with its ranks.

    Both values are exposed through read-only properties.
    """

    def __init__(self, comm_handle: int, ranks: list[int]):
        """Store the communicator handle and the ranks belonging to the group.

        Args:
            comm_handle: Integer handle identifying the communicator
                (presumably issued by a custom backend; confirm against callers).
            ranks: Ranks that are members of this group.
        """
        self._comm = comm_handle
        # Keep the value in a private attribute: ``ranks`` is a setter-less
        # property, so assigning ``self.ranks`` would raise AttributeError.
        self._ranks = ranks

    @property
    def handle(self) -> int:
        """Return the communicator handle passed at construction."""
        return self._comm

    @property
    def ranks(self) -> list[int]:
        """Return the ranks passed at construction."""
        # Read the backing field; returning ``self.ranks`` would recurse forever.
        return self._ranks
28+
29+
30+ DistributedProcessGroup = torch_dist .ProcessGroup | CommGroup
31+
32+
1233class Distributed (ABC ):
1334 @abstractmethod
1435 def init_process_group (
@@ -24,7 +45,7 @@ def init_process_group(
2445 @abstractmethod
2546 def destroy_process_group (
2647 self ,
27- group : torch_dist . ProcessGroup | int | None = None ,
48+ group : DistributedProcessGroup | None = None ,
2849 ):
2950 raise NotImplementedError
3051
@@ -37,16 +58,17 @@ def all_gather_object(
3758 self ,
3859 object_list : list [Any ],
3960 obj : Any ,
40- group : torch_dist . ProcessGroup | int | None = None ,
61+ group : DistributedProcessGroup | None = None ,
4162 ):
4263 raise NotImplementedError
4364
4465 @abstractmethod
4566 def all_reduce (
4667 self ,
4768 tensor : torch .Tensor ,
48- op : torch_dist .ReduceOp ,
49- group : torch_dist .ProcessGroup | int | None = None ,
69+ op : torch_dist .ReduceOp .RedOpType ,
70+ group : DistributedProcessGroup | None = None ,
71+ ** kwargs ,
5072 ):
5173 raise NotImplementedError
5274
@@ -55,27 +77,89 @@ def broadcast(
5577 self ,
5678 tensor : torch .Tensor ,
5779 src : int ,
58- group : torch_dist .ProcessGroup | int | None = None ,
80+ group : DistributedProcessGroup | None = None ,
81+ ** kwargs ,
5982 ):
6083 raise NotImplementedError
6184
6285 @abstractmethod
6386 def barrier (
6487 self ,
65- group : torch_dist .ProcessGroup | int | None = None ,
88+ group : DistributedProcessGroup | None = None ,
89+ ** kwargs ,
6690 ):
6791 raise NotImplementedError
6892
6993 @abstractmethod
7094 def new_group (
7195 self ,
7296 ranks : list [int ],
97+ ** kwargs ,
7398 ):
7499 raise NotImplementedError
75100
76101
class TorchBackend(Distributed):
    """``Distributed`` implementation that delegates to ``torch.distributed``."""

    def __init__(self, backend_type: str):
        """Remember the backend name used later by ``init_process_group``.

        Args:
            backend_type: Value passed as ``backend=`` to
                ``torch.distributed.init_process_group``.
        """
        self.backend_type = backend_type

    def init_process_group(
        self,
        host: str,
        port: int,
        rank: int,
        world_size: int,
        timeout: timedelta,
    ):
        """Create a TCP store and initialise the default torch process group.

        Args:
            host: Host name handed to the ``TCPStore``.
            port: Port handed to the ``TCPStore``.
            rank: Rank of this process; rank 0 acts as the store master.
            world_size: Total number of participating processes.
            timeout: Timeout applied to both the store and the process group.
        """
        store = torch_dist.TCPStore(
            host, port, world_size, timeout=timeout, is_master=(rank == 0)
        )
        torch_dist.init_process_group(
            backend=self.backend_type,
            world_size=world_size,
            rank=rank,
            timeout=timeout,
            store=store,
        )

    def destroy_process_group(self, group: DistributedProcessGroup | None = None):
        """Destroy ``group`` via ``torch.distributed.destroy_process_group``."""
        torch_dist.destroy_process_group(group)

    def is_initialized(self) -> bool:
        """Return ``torch.distributed.is_initialized()``."""
        return torch_dist.is_initialized()

    def all_gather_object(
        self, object_list: list[Any], obj: Any, group: DistributedProcessGroup | None = None
    ):
        """Gather ``obj`` from the processes in ``group`` into ``object_list``."""
        torch_dist.all_gather_object(object_list, obj, group)

    def all_reduce(
        self,
        tensor: torch.Tensor,
        op: torch_dist.ReduceOp.RedOpType = torch_dist.ReduceOp.SUM,
        group: DistributedProcessGroup | None = None,
        **kwargs,
    ) -> Any:
        """Reduce ``tensor`` with ``op`` via ``torch.distributed.all_reduce``.

        Extra keyword arguments are forwarded unchanged.

        Returns:
            Whatever ``torch.distributed.all_reduce`` returns. Propagating it
            lets callers that forward ``async_op=True`` wait on the result.
        """
        return torch_dist.all_reduce(tensor, op, group, **kwargs)

    def broadcast(
        self,
        tensor: torch.Tensor,
        src: int = 0,
        group: DistributedProcessGroup | None = None,
        **kwargs,
    ) -> Any:
        """Broadcast ``tensor`` from rank ``src`` via ``torch.distributed.broadcast``.

        Extra keyword arguments are forwarded unchanged.

        Returns:
            Whatever ``torch.distributed.broadcast`` returns.
        """
        return torch_dist.broadcast(tensor, src, group, **kwargs)

    def barrier(self, group: DistributedProcessGroup | None = None, **kwargs) -> Any:
        """Call ``torch.distributed.barrier`` for ``group``, forwarding ``kwargs``.

        Returns:
            Whatever ``torch.distributed.barrier`` returns.
        """
        return torch_dist.barrier(group, **kwargs)

    def new_group(self, ranks: list[int], **kwargs) -> DistributedProcessGroup | None:
        """Create a group for ``ranks`` via ``torch.distributed.new_group``."""
        return torch_dist.new_group(ranks, **kwargs)
159+
160+
77161# specific device instance
78- _BACKEND_INSTANCE = None
162+ _BACKEND_INSTANCE : Distributed = TorchBackend ( backend_type = "nccl" )
79163
80164_pickler = pickle .Pickler
81165_unpickler = pickle .Unpickler
@@ -112,7 +196,7 @@ def _flatten_for_scatter_gather(
112196
113197
114198def _common_all_gather_object (
115- comm : Any ,
199+ comm : CommunicatorProtocol ,
116200 device : torch .device ,
117201 world_size : int ,
118202 object_list : list [Any ],
@@ -144,83 +228,67 @@ def init_process_group(
144228 port : int ,
145229 rank : int ,
146230 world_size : int ,
231+ custom_dist : bool ,
147232 backend : str ,
148233 timeout : timedelta = timedelta (seconds = 300 ),
149234):
150235 global _BACKEND_INSTANCE
151236
152- mapping = {
153- "nccl" : ".nccl.DistributedNccl" ,
154- "hccl" : ".hccl.DistributedHccl" ,
155- }
237+ if not custom_dist :
238+ _BACKEND_INSTANCE = TorchBackend (backend_type = backend )
239+ else :
240+ mapping = {
241+ "nccl" : ".nccl.DistributedNccl" ,
242+ "hccl" : ".hccl.DistributedHccl" ,
243+ }
244+ if backend not in mapping :
245+ raise ValueError (f"Unsupported custom backend: { backend } " )
246+
247+ module_path , class_name = mapping [backend ].rsplit ("." , 1 )
248+ module = importlib .import_module (module_path , "checkpoint_engine.distributed" )
249+ backend_class = getattr (module , class_name )
250+ _BACKEND_INSTANCE = backend_class ()
156251
157- if backend not in mapping :
158- raise ValueError (f"Unsupported device type: { backend } " )
159-
160- module_path , class_name = mapping [backend ].rsplit ("." , 1 )
161- module = importlib .import_module (module_path , "checkpoint_engine.distributed" )
162- backend_class = getattr (module , class_name )
163-
164- _BACKEND_INSTANCE = backend_class ()
165252 _BACKEND_INSTANCE .init_process_group (host , port , rank , world_size , timeout )
166253
167254
def destroy_process_group(group: DistributedProcessGroup | None = None):
    """Ask the active backend instance to destroy ``group``."""
    backend = _BACKEND_INSTANCE
    backend.destroy_process_group(group)
173257
174258
def is_initialized() -> bool:
    """Return the active backend instance's ``is_initialized()`` result."""
    backend = _BACKEND_INSTANCE
    return backend.is_initialized()
179261
180262
def all_gather_object(
    object_list: list[Any],
    obj: Any,
    group: DistributedProcessGroup | None = None,
):
    """Delegate an object all-gather to the active backend instance.

    Args:
        object_list: Passed to the backend as the output list.
        obj: The object contributed by this process.
        group: Process group to operate on, or ``None``.
    """
    backend = _BACKEND_INSTANCE
    backend.all_gather_object(object_list, obj, group)
190269
191270
def all_reduce(
    tensor: torch.Tensor,
    op: torch_dist.ReduceOp.RedOpType = torch_dist.ReduceOp.SUM,
    group: DistributedProcessGroup | None = None,
    **kwargs,
):
    """Delegate an all-reduce of ``tensor`` to the active backend instance.

    Args:
        tensor: Tensor handed to the backend.
        op: Reduction operation; defaults to ``ReduceOp.SUM``.
        group: Process group to operate on, or ``None``.
        **kwargs: Forwarded unchanged to the backend.

    Returns:
        Whatever the backend's ``all_reduce`` returns. The result is passed
        through rather than dropped so that any handle produced for forwarded
        options (e.g. ``async_op=True``) reaches the caller.
    """
    return _BACKEND_INSTANCE.all_reduce(tensor, op, group, **kwargs)
202278
203279
def broadcast(
    tensor: torch.Tensor,
    src: int = 0,
    group: DistributedProcessGroup | None = None,
    **kwargs,
):
    """Delegate a broadcast of ``tensor`` from rank ``src`` to the active backend.

    Args:
        tensor: Tensor handed to the backend.
        src: Source rank; defaults to 0.
        group: Process group to operate on, or ``None``.
        **kwargs: Forwarded unchanged to the backend.

    Returns:
        Whatever the backend's ``broadcast`` returns, passed through so a
        handle produced for forwarded options is not lost.
    """
    return _BACKEND_INSTANCE.broadcast(tensor, src, group, **kwargs)
214287
215288
def barrier(group: DistributedProcessGroup | None = None, **kwargs):
    """Delegate a barrier on ``group`` to the active backend instance.

    Args:
        group: Process group to operate on, or ``None``.
        **kwargs: Forwarded unchanged to the backend.

    Returns:
        Whatever the backend's ``barrier`` returns, passed through so a handle
        produced for forwarded options is not lost.
    """
    return _BACKEND_INSTANCE.barrier(group, **kwargs)
221291
222292
def new_group(ranks: list[int], **kwargs) -> DistributedProcessGroup | None:
    """Create a group for ``ranks`` through the active backend instance.

    Keyword arguments are forwarded unchanged, and the backend's result is
    returned as-is.
    """
    backend = _BACKEND_INSTANCE
    created = backend.new_group(ranks, **kwargs)
    return created
0 commit comments