Skip to content

Commit b701319

Browse files
authored
Add RMM as a dependency of the GPU code (#753)
1 parent a864650 commit b701319

6 files changed

Lines changed: 31 additions & 11 deletions

File tree

implicit/gpu/CMakeLists.txt

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ else()
1616
add_compile_options(-DCCCL_IGNORE_DEPRECATED_STREAM_REF_HEADER)
1717

1818
# use rapids-cmake to install dependencies
19-
set(rapids-cmake-version "25.02")
20-
file(DOWNLOAD https://raw.githubusercontent.com/rapidsai/rapids-cmake/refs/heads/branch-25.02/RAPIDS.cmake
19+
set(rapids-cmake-version "26.02")
20+
file(DOWNLOAD https://raw.githubusercontent.com/rapidsai/rapids-cmake/refs/heads/release/26.02/RAPIDS.cmake
2121
${CMAKE_BINARY_DIR}/RAPIDS.cmake)
2222
include(${CMAKE_BINARY_DIR}/RAPIDS.cmake)
2323
include(rapids-cmake)
@@ -37,11 +37,11 @@ else()
3737
include(${rapids-cmake-dir}/cpm/rmm.cmake)
3838
rapids_cpm_rmm()
3939

40-
rapids_cpm_find(raft 25.02
40+
rapids_cpm_find(raft 26.02
4141
GLOBAL_TARGETS raft::raft
4242
CPM_ARGS
4343
GIT_REPOSITORY https://github.com/rapidsai/raft.git
44-
GIT_TAG v25.02.00
44+
GIT_TAG v26.02.00
4545
SOURCE_SUBDIR cpp
4646
OPTIONS
4747
"BUILD_TESTS OFF"

implicit/gpu/__init__.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,28 @@
33
import warnings
44

55
HAS_CUDA = False
6+
HAS_RMM = False
7+
68
try:
9+
# RMM is required to enable GPU support - install with 'pip install rmm-cu13'
10+
# Note that we need to import rmm here before importing the _cuda.so extension
11+
import rmm # noqa
12+
13+
HAS_RMM = True
14+
715
from ._cuda import * # noqa
816

917
get_device_count() # noqa pylint: disable=undefined-variable
1018
HAS_CUDA = True
1119

1220
except RuntimeError as e:
21+
import warnings
22+
1323
warnings.warn(
1424
f"CUDA extension is built, but disabling GPU support because of '{e}'",
1525
)
1626
except ImportError as e:
17-
warnings.warn(
18-
f"Disabling GPU support because of '{e}'",
19-
)
27+
if HAS_RMM:
28+
warnings.warn(
29+
f"Disabling GPU support because of '{e}'",
30+
)

implicit/gpu/als.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,9 @@ def __init__(
5555
calculate_training_loss=False,
5656
random_state=None,
5757
):
58+
if not implicit.gpu.HAS_RMM:
59+
raise ValueError("RMM isn't installed, can't train on GPU.")
60+
5861
if not implicit.gpu.HAS_CUDA:
5962
raise ValueError("No CUDA extension has been built, can't train on GPU.")
6063

implicit/gpu/bpr.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,9 @@ def __init__(
5656
random_state=None,
5757
):
5858
super().__init__()
59+
if not implicit.gpu.HAS_RMM:
60+
raise ValueError("RMM isn't installed, can't train on GPU.")
61+
5962
if not implicit.gpu.HAS_CUDA:
6063
raise ValueError("No CUDA extension has been built, can't train on GPU.")
6164

implicit/gpu/knn.cu

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@
55
#include <cub/device/device_segmented_radix_sort.cuh>
66
#include <cuda_runtime.h>
77
#include <rmm/cuda_stream_view.hpp>
8-
#include <rmm/mr/device/cuda_memory_resource.hpp>
9-
#include <rmm/mr/device/device_memory_resource.hpp>
10-
#include <rmm/mr/device/pool_memory_resource.hpp>
8+
#include <rmm/mr/cuda_memory_resource.hpp>
9+
#include <rmm/mr/device_memory_resource.hpp>
10+
#include <rmm/mr/pool_memory_resource.hpp>
1111
#include <thrust/device_ptr.h>
1212
#include <thrust/functional.h>
1313
#include <thrust/iterator/constant_iterator.h>

pyproject.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,11 @@ classifiers = [
3131
]
3232
dependencies = ["numpy>=1.17.0", "scipy>=0.16", "tqdm>=4.27", "threadpoolctl"]
3333

34+
[project.optional-dependencies]
35+
gpu = ["rmm-cu13"]
36+
3437
[tool.scikit-build]
35-
wheel.exclude = ["**.pyx", "**.cmake", "**.cuh", "**.hpp", "**.h", "**.hpp.in"]
38+
wheel.exclude = ["**.pyx", "**.cmake", "**.cuh", "**.hpp", "**.h", "**.hpp.in", "**librapids_logger.so", "**librmm.so"]
3639
cmake.version = ">=3.21.0"
3740
ninja.version = ">=1.10"
3841
build-dir = "build"

0 commit comments

Comments
 (0)