Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion .github/actions/build-cuda-release/action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,11 @@ runs:
pip install auditwheel build patchelf setuptools
python setup.py clean --all
MLX_BUILD_STAGE=2 python -m build -w
bash python/scripts/repair_cuda.sh ${{ inputs.arch }}

auditwheel repair dist/* \
--plat manylinux_2_35_${{ inputs.arch }} \
--exclude libcublas* \
--exclude libcuda* \
--exclude libcudnn* \
--exclude libnccl* \
--exclude libnvrtc*
12 changes: 12 additions & 0 deletions mlx/backend/cuda/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,18 @@ message(STATUS "CUDA architectures: ${MLX_CUDA_ARCHITECTURES}")
set_target_properties(mlx PROPERTIES CUDA_ARCHITECTURES
"${MLX_CUDA_ARCHITECTURES}")

if(MLX_BUILD_PYTHON_BINDINGS)
set_property(
TARGET mlx
APPEND
PROPERTY INSTALL_RPATH
# The paths here should match the install_requires in setup.py.
"$ORIGIN/../../nvidia/cublas/lib"
"$ORIGIN/../../nvidia/cuda_nvrtc/lib"
"$ORIGIN/../../nvidia/cudnn/lib"
"$ORIGIN/../../nvidia/nccl/lib")
endif()
Comment on lines +159 to +169
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I'm building python locally and using the system installed libraries that will still work right?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah it still works, setting RPATH does not remove other lib search paths.


# ------------------------ Dependencies ------------------------

# Use fixed version of CCCL.
Expand Down
26 changes: 0 additions & 26 deletions python/scripts/repair_cuda.sh

This file was deleted.

15 changes: 9 additions & 6 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,22 +79,22 @@ def build_extension(self, ext: CMakeExtension) -> None:
if not build_temp.exists():
build_temp.mkdir(parents=True)

build_python = "ON"
install_prefix = f"{extdir}{os.sep}"
install_prefix = extdir
pybind_out_dir = extdir
if build_stage == 1:
# Don't include MLX libraries in the wheel
install_prefix = f"{build_temp}"
install_prefix = build_temp
elif build_stage == 2:
# Don't include Python bindings in the wheel
build_python = "OFF"
pybind_out_dir = build_temp
cmake_args = [
f"-DCMAKE_INSTALL_PREFIX={install_prefix}",
f"-DMLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY={pybind_out_dir}",
f"-DCMAKE_BUILD_TYPE={cfg}",
f"-DMLX_BUILD_PYTHON_BINDINGS={build_python}",
"-DMLX_BUILD_PYTHON_BINDINGS=ON",
"-DMLX_BUILD_TESTS=OFF",
"-DMLX_BUILD_BENCHMARKS=OFF",
"-DMLX_BUILD_EXAMPLES=OFF",
f"-DMLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY={extdir}{os.sep}",
]
if build_stage == 2 and build_cuda:
# Last arch is always real and virtual for forward-compatibility
Expand Down Expand Up @@ -313,6 +313,9 @@ def get_tag(self) -> tuple[str, str, str]:
elif build_cuda:
toolkit = cuda_toolkit_major_version()
name = f"mlx-cuda-{toolkit}"
# Note: update following files when new dependency is added:
# * .github/actions/build-cuda-release/action.yml
# * mlx/backend/cuda/CMakeLists.txt
if toolkit == 12:
install_requires += [
"nvidia-cublas-cu12==12.9.*",
Expand Down
Loading