Skip to content

Commit 12f525e

Browse files
author
yexin
committed
fix bugs
1 parent 21ab991 commit 12f525e

2 files changed

Lines changed: 77 additions & 117 deletions

File tree

checkpoint_engine/distributed.py

Lines changed: 73 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import logging
44
import os
55
import pickle
6+
from datetime import timedelta
67
from enum import Enum
78
from typing import Any, List, Optional
89

@@ -79,7 +80,7 @@ def init_process_group(
7980
port: int,
8081
rank: int,
8182
world_size: int,
82-
timeout: int = 300,
83+
timeout: timedelta = timedelta(seconds=300),
8384
**kwargs,
8485
):
8586
self._host = host
@@ -88,7 +89,9 @@ def init_process_group(
8889
self._world_size = world_size
8990
self._device = torch.device("cuda", rank)
9091

91-
self.pg = StatelessProcessGroup.create(host, port, rank, world_size, store_timeout=timeout)
92+
self.pg = StatelessProcessGroup.create(
93+
host, port, rank, world_size, store_timeout=int(timeout.total_seconds())
94+
)
9295

9396
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
9497

@@ -186,6 +189,27 @@ def new_group(self, ranks):
186189
)
187190
from vllm_ascend.utils import current_stream
188191

192+
class HcclCommConfig(ctypes.Structure):
193+
_fields_ = [
194+
("size", ctypes.c_size_t),
195+
("magic_word", ctypes.c_uint32),
196+
("version", ctypes.c_uint32),
197+
("reserved", ctypes.c_uint64),
198+
("hccl_buffer_size", ctypes.c_uint32),
199+
("hccl_deterministic", ctypes.c_uint32),
200+
("hccl_comm_name", ctypes.c_char * 128),
201+
("hccl_udi", ctypes.c_char * 128),
202+
("hccl_op_expansion_mode", ctypes.c_uint32),
203+
("hccl_rdma_traffic_class", ctypes.c_uint32),
204+
("hccl_rdma_service_level", ctypes.c_uint32),
205+
("hcll_world_rank_id", ctypes.c_uint32),
206+
("hccl_job_id", ctypes.c_uint64),
207+
("comm_engine", ctypes.c_int32),
208+
("thread_num", ctypes.c_uint32),
209+
("notify_num_per_thread", ctypes.c_uint32),
210+
("acl_graph_zero_copy_enable", ctypes.c_uint8),
211+
]
212+
189213
orig_exported_functions = HCCLLibrary.exported_functions
190214
extended_functions = [
191215
# HcclResult HcclAllGather(
@@ -217,7 +241,7 @@ def new_group(self, ranks):
217241
ctypes.POINTER(ctypes.c_uint32),
218242
ctypes.c_uint64,
219243
ctypes.c_uint32,
220-
ctypes.POINTER(hcclUniqueId),
244+
ctypes.POINTER(HcclCommConfig),
221245
ctypes.POINTER(hcclComm_t),
222246
],
223247
),
@@ -228,27 +252,6 @@ def hccl_all_gather(self, send_buf, recv_buf, count, data_type, comm, stream):
228252
self._funcs["HcclAllGather"](send_buf, recv_buf, count, data_type, comm, stream)
229253
)
230254

231-
class HcclCommConfig(ctypes.Structure):
232-
_fields_ = [
233-
("size", ctypes.c_size_t),
234-
("magic_word", ctypes.c_uint32),
235-
("version", ctypes.c_uint32),
236-
("reserved", ctypes.c_uint64),
237-
("hccl_buffer_size", ctypes.c_uint32),
238-
("hccl_deterministic", ctypes.c_uint32),
239-
("hccl_comm_name", ctypes.c_char * 128),
240-
("hccl_udi", ctypes.c_char * 128),
241-
("hccl_op_expansion_mode", ctypes.c_uint32),
242-
("hccl_rdma_traffic_class", ctypes.c_uint32),
243-
("hccl_rdma_service_level", ctypes.c_uint32),
244-
("hcll_world_rank_id", ctypes.c_uint32),
245-
("hccl_job_id", ctypes.c_uint64),
246-
("comm_engine", ctypes.c_int32),
247-
("thread_num", ctypes.c_uint32),
248-
("notify_num_per_thread", ctypes.c_uint32),
249-
("acl_graph_zero_copy_enable", ctypes.c_uint8),
250-
]
251-
252255
def hccl_create_subcomm_config(
253256
self, comm, ranks_size, c_rank_ids, subcomm_id, subcomm_rank, comm_config
254257
):
@@ -274,55 +277,13 @@ def hccl_create_subcomm_config(
274277
class PyHcclCommunicatorEx(PyHcclCommunicator):
275278
def __init__(self, group, device):
276279
super().__init__(group, device)
277-
self.subcomms = {}
278280
self.subcomm_id = 1
279281

280-
def destroy_comm(self):
281-
self.hccl.hcclCommDestroy(self.comm)
282-
283-
def all_reduce(
284-
self,
285-
in_tensor: torch.Tensor,
286-
op: ReduceOp = ReduceOp.SUM,
287-
stream=None,
288-
) -> torch.Tensor:
289-
if self.disabled:
290-
return None
291-
assert in_tensor.device == self.device, (
292-
f"this hccl communicator is created to work on {self.device}, "
293-
f"but the input tensor is on {in_tensor.device}"
294-
)
295-
out_tensor = torch.empty_like(in_tensor)
296-
if stream is None:
297-
stream = current_stream()
298-
self.hccl.hcclAllReduce(
299-
buffer_type(in_tensor.data_ptr()),
300-
buffer_type(out_tensor.data_ptr()),
301-
in_tensor.numel(),
302-
hcclDataTypeEnum.from_torch(in_tensor.dtype),
303-
hcclRedOpTypeEnum.from_torch(op),
304-
self.comm, # todo
305-
aclrtStream_t(stream.npu_stream),
306-
)
307-
return out_tensor
308-
309-
def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
310-
if self.disabled:
311-
return None
312-
assert tensor.device == self.device, (
313-
f"this hccl communicator is created to work on {self.device}, "
314-
f"but the input tensor is on {tensor.device}"
315-
)
316-
if stream is None:
317-
stream = current_stream()
318-
self.hccl.hcclBroadcast(
319-
buffer_type(tensor.data_ptr()),
320-
tensor.numel(),
321-
hcclDataTypeEnum.from_torch(tensor.dtype),
322-
src,
323-
self.comm, # todo
324-
aclrtStream_t(stream.npu_stream),
325-
)
282+
def destroy_comm(self, comm=None):
283+
if comm:
284+
self.hccl.hcclCommDestroy(comm)
285+
else:
286+
self.hccl.hcclCommDestroy(self.comm)
326287

327288
def all_gather(self, out_tensor: torch.Tensor, in_tensor: torch.Tensor, stream=None):
328289
if self.disabled:
@@ -343,10 +304,7 @@ def all_gather(self, out_tensor: torch.Tensor, in_tensor: torch.Tensor, stream=N
343304
)
344305
return out_tensor
345306

346-
def create_subcomm(
347-
self,
348-
ranks,
349-
):
307+
def create_subcomm(self, ranks):
350308
comm_config = HcclCommConfig(
351309
size=312,
352310
magic_word=0xF0F0F0F0,
@@ -375,7 +333,6 @@ def create_subcomm(
375333
subcomm = self.hccl.hcclCreateSubCommConfig(
376334
self.comm, ranks_size, c_rank_ids, subcomm_id, subcomm_rank, comm_config
377335
)
378-
self.subcomms[subcomm_id] = subcomm
379336
self.subcomm_id += 1
380337
return subcomm
381338

@@ -391,7 +348,7 @@ def init_process_group(
391348
port: int,
392349
rank: int,
393350
world_size: int,
394-
timeout: int = 300,
351+
timeout: timedelta = timedelta(seconds=300),
395352
**kwargs,
396353
):
397354
self._host = host
@@ -401,13 +358,15 @@ def init_process_group(
401358
self._device = torch.device("npu", rank)
402359

403360
self.pg = StatelessProcessGroup.create(
404-
host, port, rank, world_size, store_timeout=timeout
361+
host, port, rank, world_size, store_timeout=int(timeout.total_seconds())
405362
)
406363
self.pyhccl = PyHcclCommunicatorEx(group=self.pg, device=self._device)
364+
self._comm = self.pyhccl.comm
407365

408366
def destroy_process_group(self, group=None):
409367
if group in self.sub_groups:
410-
group.pyhccl.destroy_comm()
368+
subcomm = ctypes.c_void_p(group)
369+
self.pyhccl.destroy_comm(subcomm)
411370
del self.sub_groups[group]
412371
return
413372

@@ -422,69 +381,70 @@ def is_initialized(self) -> bool:
422381
def all_gather_object(self, object_list: list[Any], obj: Any, group=None):
423382
if group:
424383
assert group in self.sub_groups, "invalid sub_group"
425-
pyhccl = group.pyhccl
426-
else:
427-
pyhccl = self.pyhccl
428-
_common_all_gather_object(pyhccl, self._device, self._world_size, object_list, obj)
384+
subcomm = ctypes.c_void_p(group)
385+
self.pyhccl.comm = subcomm
386+
387+
_common_all_gather_object(self.pyhccl, self._device, self._world_size, object_list, obj)
429388
current_stream().synchronize()
430389

390+
if group:
391+
self.pyhccl.comm = self._comm
392+
431393
def all_reduce(self, tensor: torch.Tensor, op=ReduceOp.SUM, group=None):
432394
if group:
433395
assert group in self.sub_groups, "invalid sub_group"
434-
pyhccl = group.pyhccl
435-
else:
436-
pyhccl = self.pyhccl
396+
subcomm = ctypes.c_void_p(group)
397+
self.pyhccl.comm = subcomm
437398

438-
out_tensor = pyhccl.all_reduce(tensor, op)
399+
out_tensor = self.pyhccl.all_reduce(tensor, op)
439400
current_stream().synchronize()
440401
tensor.copy_(out_tensor)
441402

403+
if group:
404+
self.pyhccl.comm = self._comm
405+
442406
def broadcast(self, tensor: torch.Tensor, src=None, group=None):
443407
if group:
444408
assert group in self.sub_groups, "invalid sub_group"
445409
assert src in self.sub_groups[group], "src rank not in group"
446-
pyhccl = group.pyhccl
447-
# src is rank id in global world
410+
subcomm = ctypes.c_void_p(group)
411+
self.pyhccl.comm = subcomm
412+
# convert src from its rank id in the default world to its rank index within the subcomm
448413
src = self.sub_groups[group].index(src)
449-
else:
450-
pyhccl = self.pyhccl
451414

452-
pyhccl.broadcast(tensor, src)
415+
self.pyhccl.broadcast(tensor, src)
453416
current_stream().synchronize()
454417

418+
if group:
419+
self.pyhccl.comm = self._comm
420+
455421
def barrier(self, group=None):
456422
if group:
457423
assert group in self.sub_groups, "invalid sub_group"
458-
pyhccl = group.pyhccl
459-
else:
460-
pyhccl = self.pyhccl
424+
subcomm = ctypes.c_void_p(group)
425+
self.pyhccl.comm = subcomm
461426

462427
data = torch.zeros(1, device=self._rank)
463-
pyhccl.all_reduce(data)
428+
self.pyhccl.all_reduce(data)
464429
current_stream().synchronize()
465430

431+
if group:
432+
self.pyhccl.comm = self._comm
433+
466434
def new_group(self, ranks):
467-
# ranks is None or []
435+
# if ranks is None or [], use the whole world instead
468436
if not ranks:
469-
return self
470-
471-
host = self._host
472-
port = self._port
473-
rank = self._rank
437+
ranks = list(range(self._world_size))
474438

475-
if rank not in ranks:
439+
if self._rank not in ranks:
476440
return
477441

478-
new_rank = ranks.index(rank)
479-
new_world_size = len(ranks)
480-
481-
new_dist = DistributedHccl()
482-
new_dist.init_process_group(
483-
host, port + 10, new_rank, new_world_size
484-
) # todo host maybe incorrect
485-
self.sub_groups[new_dist] = ranks
486-
487-
return new_dist
442+
subcomm = self.pyhccl.create_subcomm(ranks)
443+
value = 0
444+
if subcomm:
445+
value = subcomm.value
446+
self.sub_groups[value] = ranks
447+
return value
488448

489449
except ImportError as e:
490450
pass

checkpoint_engine/ps.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -893,9 +893,9 @@ def __init__(
893893
self._remote_rdma_devices: dict[str, set[int]] = defaultdict(set)
894894
self._mem_fraction = mem_fraction or 0.9
895895
if self.device_manager.backend == "nccl":
896-
self.dist = DistributedNccl
896+
self.dist = DistributedNccl()
897897
elif self.device_manager.backend == "hccl":
898-
self.dist = DistributedHccl
898+
self.dist = DistributedHccl()
899899
else:
900900
self.dist = torch.distributed
901901

@@ -1141,7 +1141,7 @@ def init_process_group(
11411141
"""
11421142
master_addr = master_addr or os.getenv("MASTER_ADDR")
11431143
assert master_addr, "master_addr is required"
1144-
store = self.dist.TCPStore(
1144+
store = dist.TCPStore(
11451145
master_addr,
11461146
_get_master_port(master_port),
11471147
self._world_size,
@@ -1474,7 +1474,7 @@ def _update_per_bucket(
14741474
f"[rank{self._rank}] receive error response from rank {receiver_rank} for bucket {gidx} in checkpoint {checkpoint_name}: {msg}"
14751475
)
14761476
ret_code.fill_(1)
1477-
self.dist.all_reduce(ret_code, op=self.dist.ReduceOp.SUM, group=ranks_group)
1477+
self.dist.all_reduce(ret_code, op=dist.ReduceOp.SUM, group=ranks_group)
14781478
self.device_manager.device_module.synchronize()
14791479
if ret_code.item() != 0:
14801480
# quit early if any rank failed

0 commit comments

Comments
 (0)