Skip to content

Commit 5ab13a1

Browse files
author
kip-cxj
committed
add abstraction layer etc.
1 parent 22f2b32 commit 5ab13a1

4 files changed

Lines changed: 244 additions & 200 deletions

File tree

checkpoint_engine/distributed/base.py

Lines changed: 119 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,33 @@
33
import pickle
44
from abc import ABC, abstractmethod
55
from datetime import timedelta
6-
from typing import Any
6+
from typing import Any, Protocol
77

88
import torch
99
import torch.distributed as torch_dist
1010

1111

12+
class CommunicatorProtocol(Protocol):
    """Structural type for communicator objects that expose ``all_gather``."""

    def all_gather(self, *args: Any, **kwargs: Any) -> torch.Tensor:
        """Signature only: accepts arbitrary arguments and is annotated to return a tensor."""
        ...
14+
15+
16+
class CommGroup:
    """Pair a communicator handle with the list of ranks in the group."""

    def __init__(self, comm_handle: int, ranks: list[int]):
        """Store the communicator handle and the member ranks.

        Args:
            comm_handle: Integer handle identifying the communicator.
            ranks: Ranks that belong to this group.
        """
        self._comm = comm_handle
        # Keep the list in a private attribute: the public name ``ranks`` is a
        # read-only property, so assigning ``self.ranks`` would raise
        # AttributeError and reading it inside the getter would recurse forever.
        self._ranks = ranks

    @property
    def handle(self) -> int:
        """Return the communicator handle given at construction."""
        return self._comm

    @property
    def ranks(self) -> list[int]:
        """Return the ranks given at construction."""
        return self._ranks
28+
29+
30+
DistributedProcessGroup = torch_dist.ProcessGroup | CommGroup
31+
32+
1233
class Distributed(ABC):
1334
@abstractmethod
1435
def init_process_group(
@@ -24,7 +45,7 @@ def init_process_group(
2445
@abstractmethod
2546
def destroy_process_group(
2647
self,
27-
group: torch_dist.ProcessGroup | int | None = None,
48+
group: DistributedProcessGroup | None = None,
2849
):
2950
raise NotImplementedError
3051

@@ -37,16 +58,17 @@ def all_gather_object(
3758
self,
3859
object_list: list[Any],
3960
obj: Any,
40-
group: torch_dist.ProcessGroup | int | None = None,
61+
group: DistributedProcessGroup | None = None,
4162
):
4263
raise NotImplementedError
4364

4465
@abstractmethod
4566
def all_reduce(
4667
self,
4768
tensor: torch.Tensor,
48-
op: torch_dist.ReduceOp,
49-
group: torch_dist.ProcessGroup | int | None = None,
69+
op: torch_dist.ReduceOp.RedOpType,
70+
group: DistributedProcessGroup | None = None,
71+
**kwargs,
5072
):
5173
raise NotImplementedError
5274

@@ -55,27 +77,89 @@ def broadcast(
5577
self,
5678
tensor: torch.Tensor,
5779
src: int,
58-
group: torch_dist.ProcessGroup | int | None = None,
80+
group: DistributedProcessGroup | None = None,
81+
**kwargs,
5982
):
6083
raise NotImplementedError
6184

6285
@abstractmethod
6386
def barrier(
6487
self,
65-
group: torch_dist.ProcessGroup | int | None = None,
88+
group: DistributedProcessGroup | None = None,
89+
**kwargs,
6690
):
6791
raise NotImplementedError
6892

6993
@abstractmethod
7094
def new_group(
7195
self,
7296
ranks: list[int],
97+
**kwargs,
7398
):
7499
raise NotImplementedError
75100

76101

102+
class TorchBackend(Distributed):
    """``Distributed`` implementation that forwards every call to ``torch.distributed``."""

    def __init__(self, backend_type: str):
        """Remember the backend name used later by ``init_process_group``.

        Args:
            backend_type: Value passed as ``backend=`` to
                ``torch.distributed.init_process_group``.
        """
        self.backend_type = backend_type

    def init_process_group(
        self,
        host: str,
        port: int,
        rank: int,
        world_size: int,
        timeout: timedelta,
    ):
        """Create a TCP store and initialise the torch process group on top of it.

        Rank 0 acts as the store master; ``timeout`` is applied to both the
        store and the process group.
        """
        is_master = rank == 0
        tcp_store = torch_dist.TCPStore(
            host, port, world_size, timeout=timeout, is_master=is_master
        )
        torch_dist.init_process_group(
            backend=self.backend_type,
            world_size=world_size,
            rank=rank,
            timeout=timeout,
            store=tcp_store,
        )

    def destroy_process_group(self, group: DistributedProcessGroup | None = None):
        """Forward ``group`` to ``torch.distributed.destroy_process_group``."""
        torch_dist.destroy_process_group(group)

    def is_initialized(self) -> bool:
        """Return ``torch.distributed.is_initialized()``."""
        return torch_dist.is_initialized()

    def all_gather_object(
        self,
        object_list: list[Any],
        obj: Any,
        group: DistributedProcessGroup | None = None,
    ):
        """Forward to ``torch.distributed.all_gather_object``."""
        torch_dist.all_gather_object(object_list, obj, group)

    def all_reduce(
        self,
        tensor: torch.Tensor,
        op: torch_dist.ReduceOp.RedOpType = torch_dist.ReduceOp.SUM,
        group: DistributedProcessGroup | None = None,
        **kwargs,
    ):
        """Forward to ``torch.distributed.all_reduce``; extra kwargs are passed through."""
        torch_dist.all_reduce(tensor, op, group, **kwargs)

    def broadcast(
        self,
        tensor: torch.Tensor,
        src: int = 0,
        group: DistributedProcessGroup | None = None,
        **kwargs,
    ):
        """Forward to ``torch.distributed.broadcast``; extra kwargs are passed through."""
        torch_dist.broadcast(tensor, src, group, **kwargs)

    def barrier(self, group: DistributedProcessGroup | None = None, **kwargs):
        """Forward to ``torch.distributed.barrier``; extra kwargs are passed through."""
        torch_dist.barrier(group, **kwargs)

    def new_group(self, ranks: list[int], **kwargs) -> DistributedProcessGroup | None:
        """Return the group created by ``torch.distributed.new_group`` for ``ranks``."""
        created = torch_dist.new_group(ranks, **kwargs)
        return created
159+
160+
77161
# specific device instance
78-
_BACKEND_INSTANCE = None
162+
# Active backend used by the module-level wrapper functions. Starts as a
# torch-backed instance and is rebound by init_process_group().
_BACKEND_INSTANCE: Distributed = TorchBackend(backend_type="nccl")
79163

80164
_pickler = pickle.Pickler
81165
_unpickler = pickle.Unpickler
@@ -112,7 +196,7 @@ def _flatten_for_scatter_gather(
112196

113197

114198
def _common_all_gather_object(
115-
comm: Any,
199+
comm: CommunicatorProtocol,
116200
device: torch.device,
117201
world_size: int,
118202
object_list: list[Any],
@@ -144,83 +228,67 @@ def init_process_group(
144228
port: int,
145229
rank: int,
146230
world_size: int,
231+
custom_dist: bool,
147232
backend: str,
148233
timeout: timedelta = timedelta(seconds=300),
149234
):
150235
global _BACKEND_INSTANCE
151236

152-
mapping = {
153-
"nccl": ".nccl.DistributedNccl",
154-
"hccl": ".hccl.DistributedHccl",
155-
}
237+
if not custom_dist:
238+
_BACKEND_INSTANCE = TorchBackend(backend_type=backend)
239+
else:
240+
mapping = {
241+
"nccl": ".nccl.DistributedNccl",
242+
"hccl": ".hccl.DistributedHccl",
243+
}
244+
if backend not in mapping:
245+
raise ValueError(f"Unsupported custom backend: {backend}")
246+
247+
module_path, class_name = mapping[backend].rsplit(".", 1)
248+
module = importlib.import_module(module_path, "checkpoint_engine.distributed")
249+
backend_class = getattr(module, class_name)
250+
_BACKEND_INSTANCE = backend_class()
156251

157-
if backend not in mapping:
158-
raise ValueError(f"Unsupported device type: {backend}")
159-
160-
module_path, class_name = mapping[backend].rsplit(".", 1)
161-
module = importlib.import_module(module_path, "checkpoint_engine.distributed")
162-
backend_class = getattr(module, class_name)
163-
164-
_BACKEND_INSTANCE = backend_class()
165252
_BACKEND_INSTANCE.init_process_group(host, port, rank, world_size, timeout)
166253

167254

168-
def destroy_process_group(group: torch_dist.ProcessGroup | int | None = None):
169-
if _BACKEND_INSTANCE is None:
170-
torch_dist.destroy_process_group(group)
171-
return
255+
def destroy_process_group(group: DistributedProcessGroup | None = None):
    """Destroy ``group`` through the active backend instance.

    Args:
        group: Group to destroy; forwarded unchanged (including ``None``).
    """
    backend = _BACKEND_INSTANCE
    backend.destroy_process_group(group)
173257

174258

175259
def is_initialized() -> bool:
    """Return whatever the active backend instance reports from ``is_initialized()``."""
    backend = _BACKEND_INSTANCE
    return backend.is_initialized()
179261

180262

181263
def all_gather_object(
    object_list: list[Any],
    obj: Any,
    group: DistributedProcessGroup | None = None,
):
    """Forward an all-gather of ``obj`` into ``object_list`` to the active backend.

    Args:
        object_list: Output list handed to the backend.
        obj: Object contributed by the caller.
        group: Optional group; forwarded unchanged.
    """
    backend = _BACKEND_INSTANCE
    backend.all_gather_object(object_list, obj, group)
190269

191270

192271
def all_reduce(
    tensor: torch.Tensor,
    op: torch_dist.ReduceOp.RedOpType = torch_dist.ReduceOp.SUM,
    group: DistributedProcessGroup | None = None,
    **kwargs,
):
    """Forward an all-reduce of ``tensor`` to the active backend.

    Args:
        tensor: Tensor handed to the backend.
        op: Reduction operation; defaults to ``ReduceOp.SUM``.
        group: Optional group; forwarded unchanged.
        **kwargs: Extra keyword arguments passed through to the backend.
    """
    backend = _BACKEND_INSTANCE
    backend.all_reduce(tensor, op, group, **kwargs)
202278

203279

204280
def broadcast(
    tensor: torch.Tensor,
    src: int = 0,
    group: DistributedProcessGroup | None = None,
    **kwargs,
):
    """Forward a broadcast of ``tensor`` from rank ``src`` to the active backend.

    Args:
        tensor: Tensor handed to the backend.
        src: Source rank; defaults to 0.
        group: Optional group; forwarded unchanged.
        **kwargs: Extra keyword arguments passed through to the backend.
    """
    backend = _BACKEND_INSTANCE
    backend.broadcast(tensor, src, group, **kwargs)
214287

215288

216-
def barrier(group: torch_dist.ProcessGroup | int | None = None, **kwargs):
217-
if _BACKEND_INSTANCE is None:
218-
torch_dist.barrier(group, **kwargs)
219-
return
220-
_BACKEND_INSTANCE.barrier(group)
289+
def barrier(group: DistributedProcessGroup | None = None, **kwargs):
    """Forward a barrier call to the active backend.

    Args:
        group: Optional group; forwarded unchanged.
        **kwargs: Extra keyword arguments passed through to the backend.
    """
    backend = _BACKEND_INSTANCE
    backend.barrier(group, **kwargs)
221291

222292

223-
def new_group(ranks: list[int], **kwargs) -> torch_dist.ProcessGroup | int | None:
224-
if _BACKEND_INSTANCE is None:
225-
return torch_dist.new_group(ranks, **kwargs)
226-
return _BACKEND_INSTANCE.new_group(ranks)
293+
def new_group(ranks: list[int], **kwargs) -> DistributedProcessGroup | None:
    """Ask the active backend to create a group for ``ranks`` and return its result.

    Args:
        ranks: Ranks to include in the new group.
        **kwargs: Extra keyword arguments passed through to the backend.
    """
    backend = _BACKEND_INSTANCE
    created = backend.new_group(ranks, **kwargs)
    return created

0 commit comments

Comments
 (0)