-
Notifications
You must be signed in to change notification settings - Fork 678
Expand file tree
/
Copy pathpynccl.py
More file actions
78 lines (60 loc) · 1.9 KB
/
pynccl.py
File metadata and controls
78 lines (60 loc) · 1.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
from __future__ import annotations
import functools
import os
import pathlib
from typing import TYPE_CHECKING, Any, Literal
from minisgl.env import ENV
from .utils import load_aot
if TYPE_CHECKING:
    from abc import abstractmethod

    import torch
    from tvm_ffi import Module

    class PyNCCLCommunicator:
        """Static-typing stub for the communicator returned by ``init_pynccl``.

        This class is only seen by type checkers; it lists the methods that
        callers may invoke on the opaque FFI communicator object.
        """

        @abstractmethod
        def all_reduce(self, input: torch.Tensor, op: Literal["sum"]) -> None:
            """Reduce ``input`` across ranks with ``op`` (only ``"sum"`` is typed).

            Returns ``None``, so presumably operates in place — confirm
            against the ``pynccl`` extension.
            """

        @abstractmethod
        def all_gather(self, output: torch.Tensor, input: torch.Tensor) -> None:
            """Presumably gathers ``input`` from every rank into ``output``.

            NOTE(review): output layout (e.g. concatenated by rank) is not
            visible here — confirm against the ``pynccl`` extension.
            """

        @abstractmethod
        def get_buffer(self) -> int:
            """Return an ``int`` handle for the communicator's buffer.

            NOTE(review): presumably a raw device address — confirm.
            """
else:
    # At runtime the communicator is an opaque tvm_ffi object, so the name
    # is just an alias for ``Any``.
    PyNCCLCommunicator = Any
@functools.cache
def _load_nccl_module() -> Module:
return load_aot("pynccl", cuda_files=["pynccl.cu"], extra_ldflags=["-lnccl"])
@functools.cache
def _get_pynccl_wrapper_cls():
import tvm_ffi
@tvm_ffi.register_object("minisgl.NCCLWrapper")
class PyNCCLImpl(tvm_ffi.Object):
def __init__(self, *args):
self.__ffi_init__(*args)
return PyNCCLImpl
def init_pynccl(
    *,
    tp_rank: int,
    tp_size: int,
    tp_cpu_group: torch.distributed.ProcessGroup,
    max_size_bytes: int = 0,
) -> PyNCCLCommunicator:
    """Create the NCCL communicator object for this tensor-parallel rank.

    TP rank 0 creates a NCCL unique ID through the ``pynccl`` extension and
    broadcasts it over ``tp_cpu_group``. Every rank then builds the
    ``minisgl.NCCLWrapper`` FFI object from that shared ID.

    Args:
        tp_rank: This process's rank inside the TP group; rank 0 creates the ID.
        tp_size: Number of ranks in the TP group.
        tp_cpu_group: Group used for the ID broadcast (presumably a CPU/gloo
            group, given ``broadcast_object_list`` — confirm with caller).
            ``None`` falls back to the default group.
        max_size_bytes: Requested buffer size; capped at
            ``ENV.PYNCCL_MAX_BUFFER_SIZE.value``.

    Returns:
        The FFI communicator object (typed as ``PyNCCLCommunicator``).

    Raises:
        RuntimeError: if no NCCL unique ID was received after the broadcast.
    """
    import torch

    max_size_bytes = min(max_size_bytes, ENV.PYNCCL_MAX_BUFFER_SIZE.value)
    module = _load_nccl_module()
    cls = _get_pynccl_wrapper_cls()

    # Only rank 0 creates the ID; other ranks pass a placeholder that
    # broadcast_object_list overwrites in place.
    id_list = [module.create_nccl_uid() if tp_rank == 0 else None]

    # ``src`` of broadcast_object_list is a *global* rank. Translate the TP
    # group's rank 0 so this also works when the TP group is a subgroup that
    # does not contain global rank 0 (for the WORLD group this is still 0).
    if tp_cpu_group is None:
        src = 0
    else:
        src = torch.distributed.get_global_rank(tp_cpu_group, 0)
    torch.distributed.broadcast_object_list(id_list, src=src, group=tp_cpu_group)

    nccl_id = id_list[0]
    # Explicit check instead of ``assert`` so it survives ``python -O``;
    # otherwise ``None`` would be handed to the FFI constructor.
    if nccl_id is None:
        raise RuntimeError(f"Failed to get NCCL unique ID on {tp_rank = }")
    # bypass type checking for the FFI object
    return cls(tp_rank, tp_size, max_size_bytes, nccl_id)  # type: ignore