File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -196,13 +196,12 @@ def __init__(
196196 self ._local_rdma_devices : dict [str , set [int ]] = defaultdict (set )
197197 self ._remote_rdma_devices : dict [str , set [int ]] = defaultdict (set )
198198 self ._mem_fraction = mem_fraction or float (os .getenv ("PS_MEM_FRACTION" , "0.9" ))
199+ global dist
199200 if device_type == "npu" and self .device_manager .device_type == "npu" :
200- import checkpoint_engine .distributed_hccl
201- dist = checkpoint_engine .distributed_hccl
201+ import checkpoint_engine .distributed_hccl as dist
202202 self ._device_type = "npu"
203203 elif device_type == "cuda" and self .device_manager .device_type == "cuda" :
204- import checkpoint_engine .distributed_nccl
205- dist = checkpoint_engine .distributed_nccl
204+ import checkpoint_engine .distributed_nccl as dist
206205 self ._device_type = "cuda"
207206 else :
208207 self ._device_type = "torch"
You can’t perform that action at this time.
0 commit comments