From c56938854631dbe296ad1a93ae3c6d8aa4cf5caa Mon Sep 17 00:00:00 2001 From: Cheng Date: Thu, 18 Dec 2025 10:59:49 +0000 Subject: [PATCH] Set rpath with cmake for CUDA build --- .github/actions/build-cuda-release/action.yml | 9 ++++++- mlx/backend/cuda/CMakeLists.txt | 12 +++++++++ python/scripts/repair_cuda.sh | 26 ------------------- setup.py | 15 ++++++----- 4 files changed, 29 insertions(+), 33 deletions(-) delete mode 100644 python/scripts/repair_cuda.sh diff --git a/.github/actions/build-cuda-release/action.yml b/.github/actions/build-cuda-release/action.yml index 1f5ab515c7..9f7558818c 100644 --- a/.github/actions/build-cuda-release/action.yml +++ b/.github/actions/build-cuda-release/action.yml @@ -21,4 +21,11 @@ runs: pip install auditwheel build patchelf setuptools python setup.py clean --all MLX_BUILD_STAGE=2 python -m build -w - bash python/scripts/repair_cuda.sh ${{ inputs.arch }} + + auditwheel repair dist/* \ + --plat manylinux_2_35_${{ inputs.arch }} \ + --exclude libcublas* \ + --exclude libcuda* \ + --exclude libcudnn* \ + --exclude libnccl* \ + --exclude libnvrtc* diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt index 492c13533f..26ca4773f1 100644 --- a/mlx/backend/cuda/CMakeLists.txt +++ b/mlx/backend/cuda/CMakeLists.txt @@ -156,6 +156,18 @@ message(STATUS "CUDA architectures: ${MLX_CUDA_ARCHITECTURES}") set_target_properties(mlx PROPERTIES CUDA_ARCHITECTURES "${MLX_CUDA_ARCHITECTURES}") +if(MLX_BUILD_PYTHON_BINDINGS) + set_property( + TARGET mlx + APPEND + PROPERTY INSTALL_RPATH + # The paths here should match the install_requires in setup.py. + "$ORIGIN/../../nvidia/cublas/lib" + "$ORIGIN/../../nvidia/cuda_nvrtc/lib" + "$ORIGIN/../../nvidia/cudnn/lib" + "$ORIGIN/../../nvidia/nccl/lib") +endif() + # ------------------------ Dependencies ------------------------ # Use fixed version of CCCL. diff --git a/python/scripts/repair_cuda.sh b/python/scripts/repair_cuda.sh deleted file mode 100644 index 187b0b15ff..0000000000 --- a/python/scripts/repair_cuda.sh +++ /dev/null @@ -1,26 +0,0 @@ -#!/bin/bash - -auditwheel repair dist/* \ - --plat manylinux_2_35_${1} \ - --exclude libcublas* \ - --exclude libnvrtc* \ - --exclude libcuda* \ - --exclude libcudnn* \ - --exclude libnccl* \ - -w wheel_tmp - - -mkdir wheelhouse -cd wheel_tmp -repaired_wheel=$(find . -name "*.whl" -print -quit) -unzip -q "${repaired_wheel}" -rm "${repaired_wheel}" -mlx_so="mlx/lib/libmlx.so" -rpath=$(patchelf --print-rpath "${mlx_so}") -base="\$ORIGIN/../../nvidia" -rpath=$rpath:${base}/cublas/lib:${base}/cuda_nvrtc/lib:${base}/cudnn/lib:${base}/nccl/lib -patchelf --force-rpath --set-rpath "$rpath" "$mlx_so" -python ../python/scripts/repair_record.py ${mlx_so} - -# Re-zip the repaired wheel -zip -r -q "../wheelhouse/${repaired_wheel}" . diff --git a/setup.py b/setup.py index 2c114d945c..2d2c139339 100644 --- a/setup.py +++ b/setup.py @@ -79,22 +79,22 @@ def build_extension(self, ext: CMakeExtension) -> None: if not build_temp.exists(): build_temp.mkdir(parents=True) - build_python = "ON" - install_prefix = f"{extdir}{os.sep}" + install_prefix = extdir + pybind_out_dir = extdir if build_stage == 1: # Don't include MLX libraries in the wheel - install_prefix = f"{build_temp}" + install_prefix = build_temp elif build_stage == 2: # Don't include Python bindings in the wheel - build_python = "OFF" + pybind_out_dir = build_temp cmake_args = [ f"-DCMAKE_INSTALL_PREFIX={install_prefix}", + f"-DMLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY={pybind_out_dir}", f"-DCMAKE_BUILD_TYPE={cfg}", - f"-DMLX_BUILD_PYTHON_BINDINGS={build_python}", + "-DMLX_BUILD_PYTHON_BINDINGS=ON", "-DMLX_BUILD_TESTS=OFF", "-DMLX_BUILD_BENCHMARKS=OFF", "-DMLX_BUILD_EXAMPLES=OFF", - f"-DMLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY={extdir}{os.sep}", ] if build_stage == 2 and build_cuda: # Last arch is always real and virtual for forward-compatibility @@ -313,6 +313,9 @@ def get_tag(self) -> tuple[str, str, str]: elif build_cuda: toolkit = cuda_toolkit_major_version() name = f"mlx-cuda-{toolkit}" + # Note: update following files when new dependency is added: + # * .github/actions/build-cuda-release/action.yml + # * mlx/backend/cuda/CMakeLists.txt if toolkit == 12: install_requires += [ "nvidia-cublas-cu12==12.9.*",