11/*
2- * SPDX-FileCopyrightText: Copyright (c) 2023-2025 , NVIDIA CORPORATION.
2+ * SPDX-FileCopyrightText: Copyright (c) 2023-2026 , NVIDIA CORPORATION.
33 * SPDX-License-Identifier: Apache-2.0
44 */
55
1212#include < rmm/cuda_stream_pool.hpp>
1313#include < rmm/mr/cuda_memory_resource.hpp>
1414#include < rmm/mr/per_device_resource.hpp>
15+ #include < rmm/mr/pool_memory_resource.hpp>
1516
1617#include < algorithm>
1718#include < memory>
@@ -117,9 +118,7 @@ struct device_resources_manager {
117118 // Optional specification of separate workspace memory resources for each
118119 // device. The integer in each pair indicates the device for this memory
119120 // resource.
120- std::vector<std::pair<std::shared_ptr<rmm::mr::device_memory_resource>, int >> workspace_mrs{};
121-
122- auto get_workspace_memory_resource (int device_id) {}
121+ std::vector<std::pair<raft::mr::device_resource, int >> workspace_mrs{};
123122 } params_;
124123
125124 // This struct stores the underlying resources to be shared among
@@ -152,36 +151,31 @@ struct device_resources_manager {
152151 }()},
153152 pool_mr_{[¶ms, this ]() {
154153 auto scoped_device = device_setter{device_id_};
155- auto result =
156- std::shared_ptr<rmm::mr::pool_memory_resource<rmm::mr::cuda_memory_resource>>{nullptr };
154+ auto result = std::optional<rmm::mr::pool_memory_resource>{};
157155 // If max_mem_pool_size is nullopt or non-zero, create a pool memory
158156 // resource
159157 if (params.max_mem_pool_size .value_or (1 ) != 0 ) {
160- auto * upstream =
161- dynamic_cast <rmm::mr::cuda_memory_resource*>(rmm::mr::get_current_device_resource ());
162- if (upstream != nullptr ) {
163- result =
164- std::make_shared<rmm::mr::pool_memory_resource<rmm::mr::cuda_memory_resource>>(
165- upstream,
166- params.init_mem_pool_size .value_or (rmm::percent_of_free_device_memory (50 )),
167- params.max_mem_pool_size );
168- rmm::mr::set_current_device_resource (result.get ());
169- } else {
170- RAFT_LOG_WARN (
171- " Pool allocation requested, but other memory resource has already been set and "
172- " will not be overwritten" );
173- }
158+ // TODO: reinstate the dynamic_cast<cuda_memory_resource*> guard that
159+ // skipped pool creation when a non-default resource was already set,
160+ // once CCCL exposes resource_cast or an equivalent type-query.
161+ auto upstream = rmm::mr::get_current_device_resource_ref ();
162+ result.emplace (
163+ upstream,
164+ params.init_mem_pool_size .value_or (rmm::percent_of_free_device_memory (50 )),
165+ params.max_mem_pool_size );
166+ rmm::mr::set_current_device_resource (*result);
174167 }
175168 return result;
176169 }()},
177170 workspace_mr_{[¶ms, this ]() {
178- auto result = std::shared_ptr<rmm ::mr::device_memory_resource>{ nullptr };
171+ auto result = std::optional<raft ::mr::device_resource>{ };
179172 auto iter = std::find_if (std::begin (params.workspace_mrs ),
180173 std::end (params.workspace_mrs ),
181174 [this ](auto && pair) { return pair.second == device_id_; });
182175 if (iter != std::end (params.workspace_mrs )) { result = iter->first ; }
183176 return result;
184- }()}
177+ }()},
178+ workspace_allocation_limit_{params.workspace_allocation_limit }
185179 {
186180 }
187181
@@ -216,26 +210,23 @@ struct device_resources_manager {
216210 if (pool_count () != 0 ) { result = pools_[get_thread_id () % pool_count ()]; }
217211 return result;
218212 }
219- // Return a (possibly null) shared_ptr to the pool memory resource
220- // created for this device by the manager
221- [[nodiscard]] auto get_pool_memory_resource () const { return pool_mr_; }
213+ // Return the pool memory resource created for this device by the manager (if any)
214+ [[nodiscard]] auto & get_pool_memory_resource () { return pool_mr_; }
222215 // Return the RAFT workspace allocation limit that will be used by
223216 // `device_resources` returned from this manager
224217 [[nodiscard]] auto get_workspace_allocation_limit () const
225218 {
226219 return workspace_allocation_limit_;
227220 }
228- // Return a (possibly null) shared_ptr to the memory resource that will
229- // be used for workspace allocations by `device_resources` returned from
230- // this manager
231- [[nodiscard]] auto get_workspace_memory_resource () { return workspace_mr_; }
221+ // Return the workspace memory resource for this device (if any)
222+ [[nodiscard]] auto & get_workspace_memory_resource () { return workspace_mr_; }
232223
233224 private:
234225 int device_id_;
235226 std::unique_ptr<rmm::cuda_stream_pool> streams_;
236227 std::vector<std::shared_ptr<rmm::cuda_stream_pool>> pools_;
237- std::shared_ptr <rmm::mr::pool_memory_resource<rmm::mr::cuda_memory_resource> > pool_mr_;
238- std::shared_ptr<rmm ::mr::device_memory_resource > workspace_mr_;
228+ std::optional <rmm::mr::pool_memory_resource> pool_mr_;
229+ std::optional<raft ::mr::device_resource > workspace_mr_;
239230 std::optional<std::size_t > workspace_allocation_limit_{std::nullopt };
240231 };
241232
@@ -374,8 +365,7 @@ struct device_resources_manager {
374365 }
375366
376367 // Thread-safe setter for workspace memory resources
377- void set_workspace_memory_resource_ (std::shared_ptr<rmm::mr::device_memory_resource> mr,
378- int device_id)
368+ void set_workspace_memory_resource_ (raft::mr::device_resource mr, int device_id)
379369 {
380370 auto lock = get_lock ();
381371 if (params_finalized_) {
@@ -387,9 +377,9 @@ struct device_resources_manager {
387377 std::end (params_.workspace_mrs ),
388378 [device_id](auto && pair) { return pair.second == device_id; });
389379 if (iter != std::end (params_.workspace_mrs )) {
390- iter->first = mr ;
380+ iter->first = std::move (mr) ;
391381 } else {
392- params_.workspace_mrs .emplace_back (mr , device_id);
382+ params_.workspace_mrs .emplace_back (std::move (mr) , device_id);
393383 }
394384 }
395385 }
@@ -548,7 +538,7 @@ struct device_resources_manager {
548538 * @brief Set the workspace memory resource to be used on a specific device
549539 *
550540 * RAFT device_resources objects can be built with a separate memory
551- * resource for allocating temporary workspaces. If a (non-nullptr) memory
541+ * resource for allocating temporary workspaces. If a memory
552542 * resource is provided by this setter, it will be used as the
553543 * workspace memory resource for all `device_resources` returned for the
554544 * indicated device.
@@ -557,10 +547,10 @@ struct device_resources_manager {
557547 * `raft::device_resources_manager::get_device_resources`, no change will be made,
558548 * and a warning will be emitted.
559549 */
560- static void set_workspace_memory_resource (std::shared_ptr<rmm:: mr::device_memory_resource> mr,
550+ static void set_workspace_memory_resource (raft:: mr::device_resource mr,
561551 int device_id = device_setter::get_current_device())
562552 {
563- get_manager ().set_workspace_memory_resource_ (mr , device_id);
553+ get_manager ().set_workspace_memory_resource_ (std::move (mr) , device_id);
564554 }
565555};
566556} // namespace raft
0 commit comments