Skip to content

Commit 8218a3b

Browse files
author
yexin
committed
add missing global statement
1 parent f2c0ae8 commit 8218a3b

1 file changed

Lines changed: 3 additions & 4 deletions

File tree

checkpoint_engine/ps.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -196,13 +196,12 @@ def __init__(
196196
self._local_rdma_devices: dict[str, set[int]] = defaultdict(set)
197197
self._remote_rdma_devices: dict[str, set[int]] = defaultdict(set)
198198
self._mem_fraction = mem_fraction or float(os.getenv("PS_MEM_FRACTION", "0.9"))
199+
global dist
199200
if device_type == "npu" and self.device_manager.device_type == "npu":
200-
import checkpoint_engine.distributed_hccl
201-
dist = checkpoint_engine.distributed_hccl
201+
import checkpoint_engine.distributed_hccl as dist
202202
self._device_type = "npu"
203203
elif device_type == "cuda" and self.device_manager.device_type == "cuda":
204-
import checkpoint_engine.distributed_nccl
205-
dist = checkpoint_engine.distributed_nccl
204+
import checkpoint_engine.distributed_nccl as dist
206205
self._device_type = "cuda"
207206
else:
208207
self._device_type = "torch"

0 commit comments

Comments
 (0)