forked from foundation-model-stack/fastsafetensors
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsetup.py
More file actions
62 lines (50 loc) · 2.05 KB
/
setup.py
File metadata and controls
62 lines (50 loc) · 2.05 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
# SPDX-License-Identifier: Apache-2.0
import os
import platform
from setuptools import Extension, setup
from setuptools.command.build_ext import build_ext
def MyExtension(name, sources, mod_name, *args, **kwargs):
    """Build the ``setuptools.Extension`` for the pybind11 C++ module.

    Applies the compiler/linker settings shared by every build of the
    extension. On Windows it additionally adds the DirectStorage reader
    source, MSVC flags, the D3D12/DXGI system libraries, and (optionally) the
    CUDA include directory.

    Args:
        name: Fully qualified extension name (e.g. ``"fastsafetensors.cpp"``).
        sources: C++ source files. The list is copied, never modified in place.
        mod_name: Value of the ``__MOD_NAME__`` preprocessor macro.
        *args: Forwarded positionally to ``Extension`` after ``sources``.
        **kwargs: Forwarded to ``Extension``. Caller-supplied
            ``include_dirs``, ``define_macros``, ``libraries`` and
            ``extra_compile_args`` are kept, and this function's own settings
            are appended after them. ``language`` is always set to ``"c++"``.

    Returns:
        The configured ``setuptools.Extension``.
    """
    import pybind11

    # Work on copies: the Windows branch appends to ``sources`` and must not
    # leak that change into the caller's list.
    sources = list(sources)
    include_dirs = list(kwargs.get("include_dirs") or [])
    # Documented pybind11 accessor for its bundled headers; for an installed
    # package this is <pybind11 package dir>/include.
    include_dirs.append(pybind11.get_include())
    define_macros = list(kwargs.get("define_macros") or [])
    define_macros.append(("__MOD_NAME__", mod_name))
    libraries = list(kwargs.get("libraries") or [])
    extra_compile_args = list(kwargs.get("extra_compile_args") or [])

    # Windows-specific configuration for DirectStorage + D3D12/CUDA interop
    if platform.system() == "Windows":
        sources.append("fastsafetensors/cpp/dstorage_reader.cpp")
        # c++20 required for designated initializers at ext.hpp
        extra_compile_args.append("/std:c++20")
        # Note: dstorage.dll is loaded at runtime via LoadLibrary, not linked.
        libraries.extend(["ole32", "d3d12", "dxgi", "dxguid", "uuid"])
        # CUDA interop headers: if CUDA_HOME/CUDA_PATH is set, add include path
        # for cudaExternalMemory types used by the interop bridge.
        cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH")
        if cuda_home:
            cuda_include = os.path.join(cuda_home, "include")
            if os.path.isdir(cuda_include):
                include_dirs.append(cuda_include)
    else:
        # GCC/Clang toolchain: link the C++ runtime explicitly and hide
        # non-exported symbols.
        libraries.append("stdc++")
        extra_compile_args.extend(["-fvisibility=hidden", "-std=c++17"])

    kwargs.update(
        include_dirs=include_dirs,
        define_macros=define_macros,
        libraries=libraries,
        language="c++",
        extra_compile_args=extra_compile_args,
    )
    return Extension(name, sources, *args, **kwargs)
# Non-Python files shipped inside the ``fastsafetensors.cpp`` package:
# C/C++ headers and the ``cpp.pyi`` stub for the compiled module.
package_data_patterns = ["*.hpp", "*.h", "cpp.pyi"]

# Pure-Python packages installed alongside the compiled extension.
_packages = [
    "fastsafetensors",
    "fastsafetensors.copier",
    "fastsafetensors.cpp",
    "fastsafetensors.frameworks",
]

# The single native module, built from ext.cpp with the package's own
# header directory on the include path.
_cpp_extension = MyExtension(
    name="fastsafetensors.cpp",
    sources=["fastsafetensors/cpp/ext.cpp"],
    include_dirs=["fastsafetensors/cpp"],
    mod_name="cpp",
)

setup(
    packages=_packages,
    include_package_data=True,
    package_data={"fastsafetensors.cpp": package_data_patterns},
    ext_modules=[_cpp_extension],
)