@@ -79,22 +79,22 @@ def build_extension(self, ext: CMakeExtension) -> None:
7979 if not build_temp .exists ():
8080 build_temp .mkdir (parents = True )
8181
82- build_python = "ON"
83- install_prefix = f" { extdir } { os . sep } "
82+ install_prefix = extdir
83+ pybind_out_dir = extdir
8484 if build_stage == 1 :
8585 # Don't include MLX libraries in the wheel
86- install_prefix = f" { build_temp } "
86+ install_prefix = build_temp
8787 elif build_stage == 2 :
8888 # Don't include Python bindings in the wheel
89- build_python = "OFF"
89+ pybind_out_dir = build_temp
9090 cmake_args = [
9191 f"-DCMAKE_INSTALL_PREFIX={ install_prefix } " ,
92+ f"-DMLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY={ pybind_out_dir } " ,
9293 f"-DCMAKE_BUILD_TYPE={ cfg } " ,
93- f "-DMLX_BUILD_PYTHON_BINDINGS={ build_python } " ,
94+ "-DMLX_BUILD_PYTHON_BINDINGS=ON " ,
9495 "-DMLX_BUILD_TESTS=OFF" ,
9596 "-DMLX_BUILD_BENCHMARKS=OFF" ,
9697 "-DMLX_BUILD_EXAMPLES=OFF" ,
97- f"-DMLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY={ extdir } { os .sep } " ,
9898 ]
9999 if build_stage == 2 and build_cuda :
100100 # Last arch is always real and virtual for forward-compatibility
@@ -313,6 +313,9 @@ def get_tag(self) -> tuple[str, str, str]:
313313 elif build_cuda :
314314 toolkit = cuda_toolkit_major_version ()
315315 name = f"mlx-cuda-{ toolkit } "
316+ # Note: update following files when new dependency is added:
317+ # * .github/actions/build-cuda-release/action.yml
318+ # * mlx/backend/cuda/CMakeLists.txt
316319 if toolkit == 12 :
317320 install_requires += [
318321 "nvidia-cublas-cu12==12.9.*" ,
0 commit comments