-
Notifications
You must be signed in to change notification settings - Fork 18
Expand file tree
/
Copy pathsetup.py
More file actions
150 lines (119 loc) · 4.13 KB
/
setup.py
File metadata and controls
150 lines (119 loc) · 4.13 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
"""ffpa-attn build script.
Most package metadata (name, version, dependencies, extras, URLs, etc.) lives
in ``pyproject.toml``. This ``setup.py`` exists only to drive the optional
``CUDAExtension`` build for the ``ffpa_attn._C`` C++/CUDA module via the
PyTorch build helpers.
Behavior:
- Default: pure Python / Triton-only build with no CUDA extension.
- ``ENABLE_FFPA_CUDA_IMPL=1``: build the optional ``ffpa_attn._C``
CUDA extension via ``torch.utils.cpp_extension``.
- ``FFPA_SKIP_CUDA_EXT=1``: force-skip CUDA extension even if the CUDA build
flags are enabled (used by docs CI and similar environments).
"""
import os
import subprocess
import sys
import warnings
from pathlib import Path
from setuptools import setup
# Silence every warning category for the remainder of this build script.
warnings.filterwarnings("ignore")
# Ensure the project root (containing ``env.py``) is on ``sys.path`` so that
# the build backend (``setuptools.build_meta``) can import ``env`` even when
# pip invokes setup.py from a different working directory.
_ROOT = Path(__file__).resolve().parent
if str(_ROOT) not in sys.path:
    # Index 0 gives this directory the highest import priority.
    sys.path.insert(0, str(_ROOT))
# Location of the generated version module inside the package sources.
_VERSION_FILE = _ROOT / "src" / "ffpa_attn" / "_version.py"
def _version_from_describe(describe: str) -> str:
    """Turn ``git describe --tags --long --dirty`` output into a version string.

    The expected input shape is ``<tag>-<distance>-g<sha>[-dirty]``. The tag may
    itself contain hyphens (for example ``v1.2.0-rc1``), so the fixed trailing
    fields are peeled off from the right instead of splitting from the left.

    Args:
        describe: Stripped output of ``git describe``.

    Returns:
        The tag without its leading ``v`` when the checkout sits exactly on a
        clean tag, otherwise ``"<tag>.dev<distance>"``.

    Raises:
        ValueError: If ``describe`` does not contain the expected fields.
    """
    dirty = describe.endswith("-dirty")
    if dirty:
        describe = describe.removesuffix("-dirty")
    tag, distance, _sha = describe.rsplit("-", 2)
    base = tag.removeprefix("v")
    if distance == "0" and not dirty:
        return base
    return f"{base}.dev{distance}"


def _resolve_version() -> str:
    """Return the package version string used for this build.

    Resolution order:

    1. ``setuptools_scm.get_version`` (which is also asked to write
       ``_VERSION_FILE`` through its ``write_to`` argument).
    2. ``git describe --tags --long --dirty --match v*`` executed in ``_ROOT``.
    3. The literal fallback ``"0.0.0"``.

    Every step is best-effort: a failure falls through to the next step, so
    this function never raises.

    Returns:
        The resolved version string.
    """
    try:
        from setuptools_scm import get_version as _get_scm_version

        return _get_scm_version(
            root=str(_ROOT),
            version_scheme="python-simplified-semver",
            local_scheme="no-local-version",
            write_to=str(_VERSION_FILE),
        )
    except Exception:
        # Deliberately broad: a missing setuptools_scm or any failure inside
        # it must not abort the build; fall back to plain ``git describe``.
        pass

    try:
        describe = subprocess.check_output(
            ["git", "describe", "--tags", "--long", "--dirty", "--match", "v*"],
            cwd=_ROOT,
            text=True,
            stderr=subprocess.DEVNULL,
        ).strip()
        return _version_from_describe(describe)
    except Exception:
        # No git, no matching tag, or unparsable output.
        return "0.0.0"
def _write_version_file(version: str) -> None:
    """Write a small generated version module to ``_VERSION_FILE``.

    The file defines ``version`` as the given string and ``__version__`` as an
    alias of it, and is encoded as UTF-8.

    Args:
        version: Version string to embed in the generated module.
    """
    header = '# file generated by setup.py\n'
    assignment = f'version = "{version}"\n'
    alias = '__version__ = version\n'
    _VERSION_FILE.write_text(header + assignment + alias, encoding="utf-8")
# Resolve the version once at import time; the same value is passed to
# ``setup()`` further down.
_RESOLVED_VERSION = _resolve_version()
# Only create the version module when it is missing, so an existing file (for
# example one written by setuptools_scm through ``write_to``) is left untouched.
if not _VERSION_FILE.exists():
    _write_version_file(_RESOLVED_VERSION)
def _env_flag(name: str) -> bool:
    """Return whether environment variable *name* holds a truthy value.

    The value is stripped and lower-cased before comparison; ``1``, ``true``,
    ``yes`` and ``on`` count as enabled. An unset variable defaults to ``"0"``
    and therefore counts as disabled.

    Args:
        name: Name of the environment variable to inspect.

    Returns:
        ``True`` when the variable is set to one of the truthy spellings.
    """
    raw_value = os.environ.get(name, "0")
    normalized = raw_value.strip().lower()
    return normalized in ("1", "true", "yes", "on")
# ``FFPA_SKIP_CUDA_EXT`` force-disables the CUDA extension build.
SKIP_CUDA_EXT = _env_flag("FFPA_SKIP_CUDA_EXT")
# ``env`` is a project-local module; it is imported here, after the project
# root has been put on ``sys.path`` above.
from env import ENV
# The skip flag takes precedence: ``ENV.enable_fwd_cuda_impl()`` is not even
# evaluated when the extension is being skipped.
BUILD_CUDA_EXT = (not SKIP_CUDA_EXT) and ENV.enable_fwd_cuda_impl()
# Both containers stay empty unless the CUDA extension is built below.
ext_modules = []
cmdclass = {}
if BUILD_CUDA_EXT:
    # Force a PyPI-acceptable platform tag for the produced wheel only when the
    # binary CUDA extension is actually built. ``FFPA_PLAT_TAG`` overrides the
    # default tag; an empty value disables the override entirely.
    _PLAT_TAG = os.getenv("FFPA_PLAT_TAG", "manylinux_2_34_x86_64").strip()
    if _PLAT_TAG:
        # Prefer the ``bdist_wheel`` command bundled with setuptools (>= 70.1)
        # and fall back to the legacy ``wheel`` package location for older
        # toolchains. When neither is importable the override is skipped
        # (best effort) and the wheel keeps its default platform tag.
        try:
            from setuptools.command.bdist_wheel import bdist_wheel as _bdist_wheel
        except ImportError:
            try:
                from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
            except ImportError:
                _bdist_wheel = None

        if _bdist_wheel is not None:

            class _ManylinuxBdistWheel(_bdist_wheel):
                """``bdist_wheel`` variant that stamps the wheel with ``_PLAT_TAG``."""

                def finalize_options(self):
                    super().finalize_options()
                    # Mark the wheel as platform-specific and pin its tag.
                    self.root_is_pure = False
                    self.plat_name_supplied = True
                    self.plat_name = _PLAT_TAG

            cmdclass["bdist_wheel"] = _ManylinuxBdistWheel

    # Imported only on this branch, so the default pure-Python build never
    # imports torch from this script.
    from torch.utils.cpp_extension import BuildExtension, CUDAExtension

    ENV.list_ffpa_env()

    # One ``-gencode`` flag pair per target architecture reported by ``ENV``.
    cc_flag = []
    for _sm in ENV.get_build_arch_list():
        cc_flag.extend(("-gencode", f"arch=compute_{_sm},code=sm_{_sm}"))

    ext_modules.append(
        CUDAExtension(
            # Package-internal C extension module; imported as ``ffpa_attn._C``.
            name="ffpa_attn._C",
            sources=[
                # Convert to repo-relative paths; setuptools rejects absolute paths
                # in ``sources`` for editable installs (``pip install -e .``).
                os.path.relpath(s, _ROOT)
                for s in ENV.get_build_sources(build_pkg=True)
            ],
            extra_compile_args={
                # Drop empty / whitespace-only entries from both flag lists.
                "cxx": [flag for flag in ENV.extra_gcc_flags() if flag.strip()],
                "nvcc": [
                    flag
                    for flag in (ENV.get_build_cuda_cflags(build_pkg=True) + cc_flag)
                    if flag.strip()
                ],
            },
            include_dirs=[
                # Passed as a plain string, the form setuptools documents for
                # ``include_dirs`` entries.
                os.fspath(Path(ENV.project_dir()) / "csrc" / "cuffpa"),
            ],
        )
    )
    cmdclass["build_ext"] = BuildExtension
# Most package metadata lives in ``pyproject.toml``; only the resolved version
# and the optional extension wiring are supplied here.
setup(
    version=_RESOLVED_VERSION,
    ext_modules=ext_modules,
    cmdclass=cmdclass,
)