Commit c3badb4

Author: yexin
use dist.device instead of dist.rank
Parent: 8218a3b

2 files changed: 2 additions, 2 deletions

checkpoint_engine/distributed_hccl.py

Lines changed: 1 addition & 1 deletion
@@ -298,7 +298,7 @@ def barrier(group=None):
     subcomm = ctypes.c_void_p(group)
     dist.pyhccl.comm = subcomm
 
-    data = torch.zeros(1, device=dist.rank)
+    data = torch.zeros(1, device=dist.device)
     dist.pyhccl.all_reduce(data)
     current_stream().synchronize()
 

checkpoint_engine/distributed_nccl.py

Lines changed: 1 addition & 1 deletion
@@ -270,7 +270,7 @@ def barrier(group=None):
     newcomm = ctypes.c_void_p(group)
     dist.pynccl.comm = newcomm
 
-    data = torch.zeros(1, device=dist.rank)
+    data = torch.zeros(1, device=dist.device)
     dist.pynccl.all_reduce(data)
     current_stream().synchronize()
 
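
Note on why the device argument matters, as a rough sketch: PyTorch reads a bare integer `device` as an accelerator ordinal on the local machine, so `device=dist.rank` only works when the rank happens to be a valid local device index. The commit does not show how `dist.device` is populated. The helper below, `local_device`, and its parameters are hypothetical and assume ranks are placed contiguously on each node. It only illustrates how a global rank can differ from a usable device.

```python
def local_device(global_rank: int, devices_per_node: int, device_type: str) -> str:
    """Return a device string for the accelerator a rank would use on its node.

    Hypothetical helper: assumes ranks are assigned contiguously per node,
    which the commit itself does not state.
    """
    if devices_per_node <= 0:
        raise ValueError("devices_per_node must be positive")
    # The local index wraps around at each node boundary.
    return f"{device_type}:{global_rank % devices_per_node}"


# With 8 devices per node, global rank 9 sits on the second device of the
# second node. Using the bare rank (9) as a device index would name a device
# that does not exist on that node.
print(local_device(9, 8, "npu"))
print(local_device(3, 8, "cuda"))
```

An explicit device string also carries the backend type, which a bare integer cannot express. That may be part of the reason the same fix is applied to both the HCCL and the NCCL barrier.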
