Skip to content

Commit 1a5f667

Browse files
authored
Migrate RMM usage to CCCL MR design (#2996)
## Summary - Remove `device_memory_resource` base class usage, de-template all resource and adaptor types, replace pointer-based per-device resource APIs with ref-based equivalents - Part of rapidsai/rmm#2011. Migration guide: rapidsai/rmm#2344. - Supersedes #2917 and #2920 Depends on rapidsai/rmm#2361. Depends on rapidsai/ucxx#636. ## Changes ### Core resource infrastructure - **`device_memory_resource.hpp`**: Remove `any_resource_bridge` (which inherited from `rmm::mr::device_memory_resource`), remove all `shared_ptr<device_memory_resource>` constructor overloads, consolidate to `any_resource`-only path - **`device_resources.hpp`**: Remove deprecated constructor taking `shared_ptr<device_memory_resource>`, update `get_workspace_resource()` return type (de-templated `limiting_resource_adaptor`) - **`device_resources_snmg.hpp`**: Remove stale include, de-template `pool_memory_resource` - **`handle.hpp`**: Remove deprecated constructors taking `shared_ptr<device_memory_resource>` - **`device_resources_manager.hpp`**: Retype `workspace_mrs` vector from `shared_ptr<device_memory_resource>` to `raft::mr::device_resource`, update `set_workspace_memory_resource()` signature accordingly, de-template `pool_mr_` to `optional<pool_memory_resource>`, remove `dynamic_cast` for upstream type detection, replace `get/set_current_device_resource()` with `_ref` variants ### Memory tracking - **`memory_tracking_resources.hpp`**: Remove `device_tracking_bridge` (inherited from `device_memory_resource`), use `set_current_device_resource_ref()` directly ### Call sites using `get_workspace_resource()` → `get_workspace_resource_ref()` - `select_k-inl.cuh`, `select_radix.cuh`, `select_warpsort.cuh`, `sparse/select_k-inl.cuh`, `bitmap_to_csr.cuh`, `bitset_to_csr.cuh` ### Benchmarks - **`benchmark.hpp`**: De-template `pool_memory_resource`, use `any_resource` for RAII restore - **`gather.cu`**, **`subsample.cu`**: Same pattern ### Tests - **`handle.cpp`**: Dereference `limiting_resource_adaptor*` for `device_buffer` constructor - **`device_resources_manager.cpp`**: Remove workspace-related test code for removed APIs - **`mdarray.cu`**: Remove `test_device_resource_bridge_unwrap` (bridge no longer exists) - **`multi_variable_gaussian.cu`**: `get_current_device_resource()` → `get_current_device_resource_ref()` Authors: - Bradley Dice (https://github.com/bdice) Approvers: - Divye Gala (https://github.com/divyegala) URL: #2996
1 parent eab1539 commit 1a5f667

19 files changed

Lines changed: 129 additions & 425 deletions

cpp/bench/prims/common/benchmark.hpp

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
#include <rmm/cuda_stream.hpp>
1818
#include <rmm/cuda_stream_view.hpp>
1919
#include <rmm/device_buffer.hpp>
20-
#include <rmm/mr/device_memory_resource.hpp>
20+
#include <rmm/mr/cuda_memory_resource.hpp>
2121
#include <rmm/mr/per_device_resource.hpp>
2222
#include <rmm/mr/pool_memory_resource.hpp>
2323

@@ -33,26 +33,23 @@ namespace raft::bench {
3333
*/
3434
struct using_pool_memory_res {
3535
private:
36-
rmm::mr::device_memory_resource* orig_res_;
37-
rmm::mr::cuda_memory_resource cuda_res_{};
38-
rmm::mr::pool_memory_resource<rmm::mr::device_memory_resource> pool_res_;
36+
rmm::mr::pool_memory_resource pool_res_;
37+
cuda::mr::any_resource<cuda::mr::device_accessible> prev_res_;
3938

4039
public:
4140
using_pool_memory_res(size_t initial_size, size_t max_size)
42-
: orig_res_(rmm::mr::get_current_device_resource()),
43-
pool_res_(&cuda_res_, initial_size, max_size)
41+
: pool_res_(rmm::mr::cuda_memory_resource{}, initial_size, max_size),
42+
prev_res_(rmm::mr::set_current_device_resource(pool_res_))
4443
{
45-
rmm::mr::set_current_device_resource(&pool_res_);
4644
}
4745

4846
using_pool_memory_res()
49-
: orig_res_(rmm::mr::get_current_device_resource()),
50-
pool_res_(&cuda_res_, rmm::percent_of_free_device_memory(50))
47+
: pool_res_(rmm::mr::cuda_memory_resource{}, rmm::percent_of_free_device_memory(50)),
48+
prev_res_(rmm::mr::set_current_device_resource(pool_res_))
5149
{
52-
rmm::mr::set_current_device_resource(&pool_res_);
5350
}
5451

55-
~using_pool_memory_res() { rmm::mr::set_current_device_resource(orig_res_); }
52+
~using_pool_memory_res() { rmm::mr::set_current_device_resource(prev_res_); }
5653
};
5754

5855
/**

cpp/bench/prims/matrix/gather.cu

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION.
2+
* SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION.
33
* SPDX-License-Identifier: Apache-2.0
44
*/
55

@@ -13,7 +13,7 @@
1313
#include <raft/util/itertools.hpp>
1414

1515
#include <rmm/device_uvector.hpp>
16-
#include <rmm/mr/device_memory_resource.hpp>
16+
#include <rmm/mr/per_device_resource.hpp>
1717
#include <rmm/mr/pool_memory_resource.hpp>
1818

1919
namespace raft::bench::matrix {
@@ -35,18 +35,17 @@ template <typename T, typename MapT, typename IdxT, bool Conditional = false>
3535
struct Gather : public fixture {
3636
Gather(const GatherParams<IdxT>& p)
3737
: params(p),
38-
old_mr(rmm::mr::get_current_device_resource()),
39-
pool_mr(rmm::mr::get_current_device_resource(), 2 * (1ULL << 30)),
38+
pool_mr(rmm::mr::get_current_device_resource_ref(), 2 * (1ULL << 30)),
39+
prev_res_(rmm::mr::set_current_device_resource(pool_mr)),
4040
matrix(this->handle),
4141
map(this->handle),
4242
out(this->handle),
4343
stencil(this->handle),
4444
matrix_h(this->handle)
4545
{
46-
rmm::mr::set_current_device_resource(&pool_mr);
4746
}
4847

49-
~Gather() { rmm::mr::set_current_device_resource(old_mr); }
48+
~Gather() { rmm::mr::set_current_device_resource(prev_res_); }
5049

5150
void allocate_data(const ::benchmark::State& state) override
5251
{
@@ -107,8 +106,8 @@ struct Gather : public fixture {
107106

108107
private:
109108
GatherParams<IdxT> params;
110-
rmm::mr::device_memory_resource* old_mr;
111-
rmm::mr::pool_memory_resource<rmm::mr::device_memory_resource> pool_mr;
109+
rmm::mr::pool_memory_resource pool_mr;
110+
cuda::mr::any_resource<cuda::mr::device_accessible> prev_res_;
112111
raft::device_matrix<T, IdxT> matrix, out;
113112
raft::host_matrix<T, IdxT> matrix_h;
114113
raft::device_vector<T, IdxT> stencil;

cpp/bench/prims/random/subsample.cu

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,16 +50,15 @@ template <typename T>
5050
struct sample : public fixture {
5151
sample(const sample_inputs& p)
5252
: params(p),
53-
old_mr(rmm::mr::get_current_device_resource()),
54-
pool_mr(rmm::mr::get_current_device_resource(), 2 * GiB),
53+
pool_mr(rmm::mr::get_current_device_resource_ref(), 2 * GiB),
54+
prev_mr(rmm::mr::set_current_device_resource(pool_mr)),
5555
in(make_device_vector<T, int64_t>(res, p.n_samples)),
5656
out(make_device_vector<T, int64_t>(res, p.n_train))
5757
{
58-
rmm::mr::set_current_device_resource(&pool_mr);
5958
raft::random::RngState r(123456ULL);
6059
}
6160

62-
~sample() { rmm::mr::set_current_device_resource(old_mr); }
61+
~sample() { rmm::mr::set_current_device_resource(prev_mr); }
6362
void run_benchmark(::benchmark::State& state) override
6463
{
6564
std::ostringstream label_stream;
@@ -81,8 +80,8 @@ struct sample : public fixture {
8180
private:
8281
float GiB = 1073741824.0f;
8382
raft::device_resources res;
84-
rmm::mr::device_memory_resource* old_mr;
85-
rmm::mr::pool_memory_resource<rmm::mr::device_memory_resource> pool_mr;
83+
rmm::mr::pool_memory_resource pool_mr;
84+
cuda::mr::any_resource<cuda::mr::device_accessible> prev_mr;
8685
sample_inputs params;
8786
raft::device_vector<T, int64_t> out, in;
8887
}; // struct sample

cpp/include/raft/core/device_resources.hpp

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* SPDX-FileCopyrightText: Copyright (c) 2019-2025, NVIDIA CORPORATION.
2+
* SPDX-FileCopyrightText: Copyright (c) 2019-2026, NVIDIA CORPORATION.
33
* SPDX-License-Identifier: Apache-2.0
44
*/
55

@@ -26,7 +26,7 @@
2626

2727
#include <rmm/cuda_stream_pool.hpp>
2828
#include <rmm/exec_policy.hpp>
29-
#include <rmm/mr/device_memory_resource.hpp>
29+
#include <rmm/resource_ref.hpp>
3030

3131
#include <cuda_runtime.h>
3232

@@ -52,12 +52,11 @@ namespace raft {
5252
class device_resources : public resources {
5353
public:
5454
device_resources(const device_resources& handle,
55-
std::shared_ptr<rmm::mr::device_memory_resource> workspace_resource,
55+
raft::mr::device_resource workspace_resource,
5656
std::optional<std::size_t> allocation_limit = std::nullopt)
5757
: resources{handle}
5858
{
59-
// replace the resource factory for the workspace_resources
60-
resource::set_workspace_resource(*this, workspace_resource, allocation_limit);
59+
resource::set_workspace_resource(*this, std::move(workspace_resource), allocation_limit);
6160
}
6261

6362
device_resources(const device_resources& handle) : resources{handle} {}
@@ -77,8 +76,8 @@ class device_resources : public resources {
7776
*/
7877
device_resources(rmm::cuda_stream_view stream_view = rmm::cuda_stream_per_thread,
7978
std::shared_ptr<rmm::cuda_stream_pool> stream_pool = {nullptr},
80-
std::shared_ptr<rmm::mr::device_memory_resource> workspace_resource = {nullptr},
81-
std::optional<std::size_t> allocation_limit = std::nullopt)
79+
std::optional<raft::mr::device_resource> workspace_resource = std::nullopt,
80+
std::optional<std::size_t> allocation_limit = std::nullopt)
8281
: resources{}
8382
{
8483
resources::add_resource_factory(std::make_shared<resource::device_id_resource_factory>());
@@ -87,7 +86,7 @@ class device_resources : public resources {
8786
resources::add_resource_factory(
8887
std::make_shared<resource::cuda_stream_pool_resource_factory>(stream_pool));
8988
if (workspace_resource) {
90-
resource::set_workspace_resource(*this, workspace_resource, allocation_limit);
89+
resource::set_workspace_resource(*this, std::move(*workspace_resource), allocation_limit);
9190
}
9291
}
9392

@@ -214,7 +213,7 @@ class device_resources : public resources {
214213
return resource::get_subcomm(*this, key);
215214
}
216215

217-
rmm::mr::device_memory_resource* get_workspace_resource() const
216+
rmm::mr::limiting_resource_adaptor* get_workspace_resource() const
218217
{
219218
return resource::get_workspace_resource(*this);
220219
}

cpp/include/raft/core/device_resources_manager.hpp

Lines changed: 28 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION.
2+
* SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION.
33
* SPDX-License-Identifier: Apache-2.0
44
*/
55

@@ -12,6 +12,7 @@
1212
#include <rmm/cuda_stream_pool.hpp>
1313
#include <rmm/mr/cuda_memory_resource.hpp>
1414
#include <rmm/mr/per_device_resource.hpp>
15+
#include <rmm/mr/pool_memory_resource.hpp>
1516

1617
#include <algorithm>
1718
#include <memory>
@@ -117,9 +118,7 @@ struct device_resources_manager {
117118
// Optional specification of separate workspace memory resources for each
118119
// device. The integer in each pair indicates the device for this memory
119120
// resource.
120-
std::vector<std::pair<std::shared_ptr<rmm::mr::device_memory_resource>, int>> workspace_mrs{};
121-
122-
auto get_workspace_memory_resource(int device_id) {}
121+
std::vector<std::pair<raft::mr::device_resource, int>> workspace_mrs{};
123122
} params_;
124123

125124
// This struct stores the underlying resources to be shared among
@@ -152,36 +151,31 @@ struct device_resources_manager {
152151
}()},
153152
pool_mr_{[&params, this]() {
154153
auto scoped_device = device_setter{device_id_};
155-
auto result =
156-
std::shared_ptr<rmm::mr::pool_memory_resource<rmm::mr::cuda_memory_resource>>{nullptr};
154+
auto result = std::optional<rmm::mr::pool_memory_resource>{};
157155
// If max_mem_pool_size is nullopt or non-zero, create a pool memory
158156
// resource
159157
if (params.max_mem_pool_size.value_or(1) != 0) {
160-
auto* upstream =
161-
dynamic_cast<rmm::mr::cuda_memory_resource*>(rmm::mr::get_current_device_resource());
162-
if (upstream != nullptr) {
163-
result =
164-
std::make_shared<rmm::mr::pool_memory_resource<rmm::mr::cuda_memory_resource>>(
165-
upstream,
166-
params.init_mem_pool_size.value_or(rmm::percent_of_free_device_memory(50)),
167-
params.max_mem_pool_size);
168-
rmm::mr::set_current_device_resource(result.get());
169-
} else {
170-
RAFT_LOG_WARN(
171-
"Pool allocation requested, but other memory resource has already been set and "
172-
"will not be overwritten");
173-
}
158+
// TODO: reinstate the dynamic_cast<cuda_memory_resource*> guard that
159+
// skipped pool creation when a non-default resource was already set,
160+
// once CCCL exposes resource_cast or an equivalent type-query.
161+
auto upstream = rmm::mr::get_current_device_resource_ref();
162+
result.emplace(
163+
upstream,
164+
params.init_mem_pool_size.value_or(rmm::percent_of_free_device_memory(50)),
165+
params.max_mem_pool_size);
166+
rmm::mr::set_current_device_resource(*result);
174167
}
175168
return result;
176169
}()},
177170
workspace_mr_{[&params, this]() {
178-
auto result = std::shared_ptr<rmm::mr::device_memory_resource>{nullptr};
171+
auto result = std::optional<raft::mr::device_resource>{};
179172
auto iter = std::find_if(std::begin(params.workspace_mrs),
180173
std::end(params.workspace_mrs),
181174
[this](auto&& pair) { return pair.second == device_id_; });
182175
if (iter != std::end(params.workspace_mrs)) { result = iter->first; }
183176
return result;
184-
}()}
177+
}()},
178+
workspace_allocation_limit_{params.workspace_allocation_limit}
185179
{
186180
}
187181

@@ -216,26 +210,23 @@ struct device_resources_manager {
216210
if (pool_count() != 0) { result = pools_[get_thread_id() % pool_count()]; }
217211
return result;
218212
}
219-
// Return a (possibly null) shared_ptr to the pool memory resource
220-
// created for this device by the manager
221-
[[nodiscard]] auto get_pool_memory_resource() const { return pool_mr_; }
213+
// Return the pool memory resource created for this device by the manager (if any)
214+
[[nodiscard]] auto& get_pool_memory_resource() { return pool_mr_; }
222215
// Return the RAFT workspace allocation limit that will be used by
223216
// `device_resources` returned from this manager
224217
[[nodiscard]] auto get_workspace_allocation_limit() const
225218
{
226219
return workspace_allocation_limit_;
227220
}
228-
// Return a (possibly null) shared_ptr to the memory resource that will
229-
// be used for workspace allocations by `device_resources` returned from
230-
// this manager
231-
[[nodiscard]] auto get_workspace_memory_resource() { return workspace_mr_; }
221+
// Return the workspace memory resource for this device (if any)
222+
[[nodiscard]] auto& get_workspace_memory_resource() { return workspace_mr_; }
232223

233224
private:
234225
int device_id_;
235226
std::unique_ptr<rmm::cuda_stream_pool> streams_;
236227
std::vector<std::shared_ptr<rmm::cuda_stream_pool>> pools_;
237-
std::shared_ptr<rmm::mr::pool_memory_resource<rmm::mr::cuda_memory_resource>> pool_mr_;
238-
std::shared_ptr<rmm::mr::device_memory_resource> workspace_mr_;
228+
std::optional<rmm::mr::pool_memory_resource> pool_mr_;
229+
std::optional<raft::mr::device_resource> workspace_mr_;
239230
std::optional<std::size_t> workspace_allocation_limit_{std::nullopt};
240231
};
241232

@@ -374,8 +365,7 @@ struct device_resources_manager {
374365
}
375366

376367
// Thread-safe setter for workspace memory resources
377-
void set_workspace_memory_resource_(std::shared_ptr<rmm::mr::device_memory_resource> mr,
378-
int device_id)
368+
void set_workspace_memory_resource_(raft::mr::device_resource mr, int device_id)
379369
{
380370
auto lock = get_lock();
381371
if (params_finalized_) {
@@ -387,9 +377,9 @@ struct device_resources_manager {
387377
std::end(params_.workspace_mrs),
388378
[device_id](auto&& pair) { return pair.second == device_id; });
389379
if (iter != std::end(params_.workspace_mrs)) {
390-
iter->first = mr;
380+
iter->first = std::move(mr);
391381
} else {
392-
params_.workspace_mrs.emplace_back(mr, device_id);
382+
params_.workspace_mrs.emplace_back(std::move(mr), device_id);
393383
}
394384
}
395385
}
@@ -548,7 +538,7 @@ struct device_resources_manager {
548538
* @brief Set the workspace memory resource to be used on a specific device
549539
*
550540
* RAFT device_resources objects can be built with a separate memory
551-
* resource for allocating temporary workspaces. If a (non-nullptr) memory
541+
* resource for allocating temporary workspaces. If a memory
552542
* resource is provided by this setter, it will be used as the
553543
* workspace memory resource for all `device_resources` returned for the
554544
* indicated device.
@@ -557,10 +547,10 @@ struct device_resources_manager {
557547
* `raft::device_resources_manager::get_device_resources`, no change will be made,
558548
* and a warning will be emitted.
559549
*/
560-
static void set_workspace_memory_resource(std::shared_ptr<rmm::mr::device_memory_resource> mr,
550+
static void set_workspace_memory_resource(raft::mr::device_resource mr,
561551
int device_id = device_setter::get_current_device())
562552
{
563-
get_manager().set_workspace_memory_resource_(mr, device_id);
553+
get_manager().set_workspace_memory_resource_(std::move(mr), device_id);
564554
}
565555
};
566556
} // namespace raft

cpp/include/raft/core/device_resources_snmg.hpp

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
#include <raft/core/resource/resource_types.hpp>
1111

1212
#include <rmm/cuda_device.hpp>
13-
#include <rmm/mr/device_memory_resource.hpp>
1413
#include <rmm/mr/per_device_resource.hpp>
1514
#include <rmm/mr/pool_memory_resource.hpp>
1615

@@ -105,10 +104,9 @@ class device_resources_snmg : public device_resources {
105104
int device_id = raft::resource::get_device_id(dev_res);
106105
pool_device_ids_.push_back(device_id);
107106

108-
per_device_pools_.push_back(
109-
std::make_unique<rmm::mr::pool_memory_resource<rmm::mr::device_memory_resource>>(
110-
rmm::mr::get_current_device_resource_ref(),
111-
rmm::percent_of_free_device_memory(percent_of_free_memory)));
107+
per_device_pools_.push_back(std::make_unique<rmm::mr::pool_memory_resource>(
108+
rmm::mr::get_current_device_resource_ref(),
109+
rmm::percent_of_free_device_memory(percent_of_free_memory)));
112110
rmm::mr::set_per_device_resource_ref(rmm::cuda_device_id{device_id},
113111
*per_device_pools_.back());
114112
}
@@ -151,8 +149,7 @@ class device_resources_snmg : public device_resources {
151149
}
152150
}
153151
int main_gpu_id_;
154-
std::vector<std::unique_ptr<rmm::mr::pool_memory_resource<rmm::mr::device_memory_resource>>>
155-
per_device_pools_;
152+
std::vector<std::unique_ptr<rmm::mr::pool_memory_resource>> per_device_pools_;
156153
std::vector<int> pool_device_ids_;
157154
}; // class device_resources_snmg
158155

cpp/include/raft/core/handle.hpp

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* SPDX-FileCopyrightText: Copyright (c) 2019-2023, NVIDIA CORPORATION.
2+
* SPDX-FileCopyrightText: Copyright (c) 2019-2026, NVIDIA CORPORATION.
33
* SPDX-License-Identifier: Apache-2.0
44
*/
55

@@ -21,9 +21,8 @@ namespace raft {
2121
*/
2222
class handle_t : public raft::device_resources {
2323
public:
24-
handle_t(const handle_t& handle,
25-
std::shared_ptr<rmm::mr::device_memory_resource> workspace_resource)
26-
: device_resources(handle, workspace_resource)
24+
handle_t(const handle_t& handle, raft::mr::device_resource workspace_resource)
25+
: device_resources(handle, std::move(workspace_resource))
2726
{
2827
}
2928

@@ -43,8 +42,8 @@ class handle_t : public raft::device_resources {
4342
*/
4443
handle_t(rmm::cuda_stream_view stream_view = rmm::cuda_stream_per_thread,
4544
std::shared_ptr<rmm::cuda_stream_pool> stream_pool = {nullptr},
46-
std::shared_ptr<rmm::mr::device_memory_resource> workspace_resource = {nullptr})
47-
: device_resources{stream_view, stream_pool, workspace_resource}
45+
std::optional<raft::mr::device_resource> workspace_resource = std::nullopt)
46+
: device_resources{stream_view, stream_pool, std::move(workspace_resource)}
4847
{
4948
}
5049

0 commit comments

Comments
 (0)