Skip to content

Commit 8273ecd

Browse files
author
kip-cxj
committed
fix init tcp store
1 parent 65fd3ce commit 8273ecd

3 files changed

Lines changed: 17 additions & 18 deletions

File tree

checkpoint_engine/distributed/hccl.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,7 @@ def barrier(self, group: CommGroup | None = None, **kwargs):
310310
self.pyhccl.all_reduce(data)
311311
current_stream().synchronize()
312312

313-
def new_group(self, ranks: list[int], **kwargs) -> CommGroup:
313+
def new_group(self, ranks: list[int], **kwargs) -> CommGroup | None:
314314
assert self.initialized, "not initialized"
315315

316316
# ranks is None or []
@@ -319,8 +319,9 @@ def new_group(self, ranks: list[int], **kwargs) -> CommGroup:
319319
else:
320320
ranks.sort()
321321

322+
group: CommGroup = None
322323
if self.rank not in ranks:
323-
return
324+
return group
324325

325326
subcomm = self.pyhccl.create_subcomm(ranks)
326327
if subcomm:

checkpoint_engine/distributed/nccl.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ def barrier(self, group: CommGroup | None = None, **kwargs):
214214
self.pynccl.all_reduce(data)
215215
current_stream().synchronize()
216216

217-
def new_group(self, ranks: list[int], **kwargs) -> CommGroup:
217+
def new_group(self, ranks: list[int], **kwargs) -> CommGroup | None:
218218
assert self.initialized, "not initialized"
219219

220220
# ranks is None or []

checkpoint_engine/ps.py

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -553,22 +553,19 @@ def update(
553553
try:
554554
master_addr = os.getenv("MASTER_ADDR") or master_addr
555555
assert master_addr, "master_addr is required"
556-
if self._auto_pg:
557-
if not dist.is_initialized():
558-
self.init_process_group(
559-
timeout=timeout, master_addr=master_addr, master_port=master_port
560-
)
561-
manager_store = torch.distributed.distributed_c10d._get_default_store()
562-
else:
563-
# HACK: MASTER_PORT+2 for barrier store if master_port is not provided, _get_master_port() returns MASTER_PORT+1
564-
# If master_port is provided, use master_port+1 for barrier store
565-
manager_store = torch.distributed.TCPStore(
566-
master_addr,
567-
_get_master_port(master_port) + 1,
568-
self._world_size,
569-
timeout=timeout,
570-
is_master=self._rank == 0,
556+
if self._auto_pg and not dist.is_initialized():
557+
self.init_process_group(
558+
timeout=timeout, master_addr=master_addr, master_port=master_port
571559
)
560+
# HACK: when master_port is not provided, _get_master_port() returns MASTER_PORT+1, so the barrier store ends up on MASTER_PORT+2.
561+
# If master_port is provided, use master_port+1 for barrier store
562+
manager_store = torch.distributed.TCPStore(
563+
master_addr,
564+
_get_master_port(master_port) + 1,
565+
self._world_size,
566+
timeout=timeout,
567+
is_master=self._rank == 0,
568+
)
572569
# if ranks is None or [], it will use fully broadcast to update to all ranks
573570
ranks_group = dist.new_group(ranks) if ranks else None
574571
self._update_per_bucket(checkpoint_name, req_func, ranks_group, ranks)
@@ -583,6 +580,7 @@ def update(
583580
dist.destroy_process_group(ranks_group)
584581
if self._auto_pg and dist.is_initialized():
585582
dist.destroy_process_group()
583+
del manager_store
586584
self.device_manager.device_module.empty_cache()
587585
logger.info(
588586
f"[rank{self._rank}] update checkpoint {checkpoint_name} with ranks {ranks} done. "

0 commit comments

Comments
 (0)