Skip to content

Commit b2e2b19

Browse files
authored
Set rpath with cmake for CUDA build (#2932)
1 parent ab4dce4 commit b2e2b19

4 files changed

Lines changed: 29 additions & 33 deletions

File tree

.github/actions/build-cuda-release/action.yml

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,4 +21,11 @@ runs:
2121
pip install auditwheel build patchelf setuptools
2222
python setup.py clean --all
2323
MLX_BUILD_STAGE=2 python -m build -w
24-
bash python/scripts/repair_cuda.sh ${{ inputs.arch }}
24+
25+
auditwheel repair dist/* \
26+
--plat manylinux_2_35_${{ inputs.arch }} \
27+
--exclude libcublas* \
28+
--exclude libcuda* \
29+
--exclude libcudnn* \
30+
--exclude libnccl* \
31+
--exclude libnvrtc*

mlx/backend/cuda/CMakeLists.txt

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,18 @@ message(STATUS "CUDA architectures: ${MLX_CUDA_ARCHITECTURES}")
156156
set_target_properties(mlx PROPERTIES CUDA_ARCHITECTURES
157157
"${MLX_CUDA_ARCHITECTURES}")
158158

159+
if(MLX_BUILD_PYTHON_BINDINGS)
160+
set_property(
161+
TARGET mlx
162+
APPEND
163+
PROPERTY INSTALL_RPATH
164+
# The paths here should match the install_requires in setup.py.
165+
"$ORIGIN/../../nvidia/cublas/lib"
166+
"$ORIGIN/../../nvidia/cuda_nvrtc/lib"
167+
"$ORIGIN/../../nvidia/cudnn/lib"
168+
"$ORIGIN/../../nvidia/nccl/lib")
169+
endif()
170+
159171
# ------------------------ Dependencies ------------------------
160172

161173
# Use fixed version of CCCL.

python/scripts/repair_cuda.sh

Lines changed: 0 additions & 26 deletions
This file was deleted.

setup.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -79,22 +79,22 @@ def build_extension(self, ext: CMakeExtension) -> None:
7979
if not build_temp.exists():
8080
build_temp.mkdir(parents=True)
8181

82-
build_python = "ON"
83-
install_prefix = f"{extdir}{os.sep}"
82+
install_prefix = extdir
83+
pybind_out_dir = extdir
8484
if build_stage == 1:
8585
# Don't include MLX libraries in the wheel
86-
install_prefix = f"{build_temp}"
86+
install_prefix = build_temp
8787
elif build_stage == 2:
8888
# Don't include Python bindings in the wheel
89-
build_python = "OFF"
89+
pybind_out_dir = build_temp
9090
cmake_args = [
9191
f"-DCMAKE_INSTALL_PREFIX={install_prefix}",
92+
f"-DMLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY={pybind_out_dir}",
9293
f"-DCMAKE_BUILD_TYPE={cfg}",
93-
f"-DMLX_BUILD_PYTHON_BINDINGS={build_python}",
94+
"-DMLX_BUILD_PYTHON_BINDINGS=ON",
9495
"-DMLX_BUILD_TESTS=OFF",
9596
"-DMLX_BUILD_BENCHMARKS=OFF",
9697
"-DMLX_BUILD_EXAMPLES=OFF",
97-
f"-DMLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY={extdir}{os.sep}",
9898
]
9999
if build_stage == 2 and build_cuda:
100100
# Last arch is always real and virtual for forward-compatibility
@@ -313,6 +313,9 @@ def get_tag(self) -> tuple[str, str, str]:
313313
elif build_cuda:
314314
toolkit = cuda_toolkit_major_version()
315315
name = f"mlx-cuda-{toolkit}"
316+
# Note: update following files when new dependency is added:
317+
# * .github/actions/build-cuda-release/action.yml
318+
# * mlx/backend/cuda/CMakeLists.txt
316319
if toolkit == 12:
317320
install_requires += [
318321
"nvidia-cublas-cu12==12.9.*",

0 commit comments

Comments
 (0)