-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathsetup.py
More file actions
56 lines (44 loc) · 1.67 KB
/
setup.py
File metadata and controls
56 lines (44 loc) · 1.67 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
from pathlib import Path
import sys
from setuptools import setup
def is_metadata_command() -> bool:
    """Return True when setup.py was invoked only for metadata/packaging queries.

    Metadata-only invocations (``egg_info``, ``dist_info``, ``sdist``,
    ``clean``, ``--name``, ``--version``) let ``get_extensions`` skip importing
    PyTorch entirely.

    A metadata command on its own is not sufficient. When the same invocation
    also requests a build, the extension must still be compiled. Examples are
    ``setup.py clean build_ext --inplace`` and ``setup.py sdist bdist_wheel``.
    Otherwise the build step would run with an empty ``ext_modules`` and
    silently produce an artifact without the ``torch_cudagraph_debug._C``
    extension.
    """
    metadata_commands = {"egg_info", "dist_info", "sdist", "clean", "--name", "--version"}
    # Commands that compile or package ext_modules; any of these forces a real build.
    build_commands = {
        "build",
        "build_ext",
        "bdist",
        "bdist_egg",
        "bdist_wheel",
        "develop",
        "editable_wheel",
        "install",
    }
    args = set(sys.argv[1:])
    # NOTE: an option value that happens to equal a build command name (e.g.
    # `--dist-dir build`) is treated as a build request. This errs toward
    # building, which fails loudly if torch is missing, instead of silently
    # dropping the extension.
    return bool(args & metadata_commands) and not (args & build_commands)
def get_extensions():
    """Compute the ``(ext_modules, cmdclass)`` pair handed to ``setup()``.

    Returns:
        ``([], {})`` when ``is_metadata_command()`` is true. In that case
        PyTorch is never imported.

        Otherwise, a one-element list holding the ``torch_cudagraph_debug._C``
        ``CUDAExtension``, plus a cmdclass mapping ``build_ext`` to PyTorch's
        ``BuildExtension``.

    Raises:
        RuntimeError: if importing ``torch`` / ``torch.utils.cpp_extension``
            fails, or if ``torch.cuda._is_compiled()`` returns false.
    """
    if is_metadata_command():
        return [], {}

    # PyTorch is imported lazily, only once we know a real build is happening.
    try:
        import torch
        from torch.utils.cpp_extension import BuildExtension, CUDAExtension
    except Exception as err:  # pragma: no cover - exercised during package build only.
        raise RuntimeError(
            "torch-cudagraph-debug requires PyTorch at build time. Install a CUDA-enabled "
            "PyTorch first, then install this package with build isolation disabled if needed."
        ) from err

    # `_is_compiled` is a private torch API; it gates on the installed torch build.
    if not torch.cuda._is_compiled():
        raise RuntimeError(
            "torch-cudagraph-debug must be built against a CUDA-enabled PyTorch installation."
        )

    project_root = Path(__file__).parent
    csrc_dir = project_root / "src" / "torch_cudagraph_debug" / "csrc"
    source_names = (
        "bindings.cpp",
        "tensor_debug/probe_context.cpp",
        "tensor_debug/tensor_format.cpp",
        "tensor_debug/compare.cpp",
    )
    # Sources are passed relative to the project root; the include dir stays absolute.
    relative_sources = [
        str((csrc_dir / source_name).relative_to(project_root)) for source_name in source_names
    ]

    shared_flags = ["-O3", "-std=c++17"]
    cudagraph_ext = CUDAExtension(
        name="torch_cudagraph_debug._C",
        sources=relative_sources,
        include_dirs=[str(csrc_dir)],
        extra_compile_args={
            "cxx": [*shared_flags, "-Wall"],
            "nvcc": list(shared_flags),
        },
    )
    build_commands = {"build_ext": BuildExtension}
    return [cudagraph_ext], build_commands
# Resolve the extension modules and command classes. get_extensions() returns
# ([], {}) for metadata-only invocations, so those never import PyTorch.
# Both values are then handed to setuptools. Other project metadata (name,
# version, ...) is not passed here; presumably it comes from pyproject.toml or
# setup.cfg — confirm.
ext_modules, cmdclass = get_extensions()
setup(
    ext_modules=ext_modules,
    cmdclass=cmdclass,
)