From c67f9fac79c64dc09aae59f2eceab370d9d8023f Mon Sep 17 00:00:00 2001 From: Bradley Dice Date: Fri, 3 Apr 2026 09:32:27 -0500 Subject: [PATCH 1/4] Migrate RMM usage to CCCL memory resource design Remove device_memory_resource base class usage, de-template all resource and adaptor types, replace pointer-based per-device resource APIs with ref-based equivalents, and update all call sites for the new signatures. Part of rapidsai/rmm#2011. --- cpp/bench/prims/common/benchmark.hpp | 20 +- cpp/bench/prims/matrix/gather.cu | 15 +- cpp/bench/prims/random/subsample.cu | 11 +- cpp/include/raft/core/device_resources.hpp | 26 +- .../raft/core/device_resources_manager.hpp | 104 +------- .../raft/core/device_resources_snmg.hpp | 11 +- cpp/include/raft/core/handle.hpp | 15 +- .../core/resource/device_memory_resource.hpp | 226 +++--------------- .../raft/matrix/detail/select_k-inl.cuh | 4 +- .../raft/matrix/detail/select_radix.cuh | 2 +- .../raft/matrix/detail/select_warpsort.cuh | 6 +- .../sparse/convert/detail/bitmap_to_csr.cuh | 2 +- .../sparse/convert/detail/bitset_to_csr.cuh | 2 +- .../sparse/matrix/detail/select_k-inl.cuh | 2 +- .../raft/util/memory_tracking_resources.hpp | 49 +--- cpp/tests/core/device_resources_manager.cpp | 42 +--- cpp/tests/core/handle.cpp | 28 +-- cpp/tests/core/mdarray.cu | 35 --- cpp/tests/random/multi_variable_gaussian.cu | 4 +- 19 files changed, 101 insertions(+), 503 deletions(-) diff --git a/cpp/bench/prims/common/benchmark.hpp b/cpp/bench/prims/common/benchmark.hpp index 913729fd03..ffa3c6d82d 100644 --- a/cpp/bench/prims/common/benchmark.hpp +++ b/cpp/bench/prims/common/benchmark.hpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION. 
* SPDX-License-Identifier: Apache-2.0 */ @@ -17,7 +17,7 @@ #include #include #include -#include +#include #include #include @@ -33,26 +33,24 @@ namespace raft::bench { */ struct using_pool_memory_res { private: - rmm::mr::device_memory_resource* orig_res_; rmm::mr::cuda_memory_resource cuda_res_{}; - rmm::mr::pool_memory_resource pool_res_; + rmm::mr::pool_memory_resource pool_res_; + cuda::mr::any_resource prev_res_; public: using_pool_memory_res(size_t initial_size, size_t max_size) - : orig_res_(rmm::mr::get_current_device_resource()), - pool_res_(&cuda_res_, initial_size, max_size) + : pool_res_(cuda_res_, initial_size, max_size), + prev_res_(rmm::mr::set_current_device_resource_ref(pool_res_)) { - rmm::mr::set_current_device_resource(&pool_res_); } using_pool_memory_res() - : orig_res_(rmm::mr::get_current_device_resource()), - pool_res_(&cuda_res_, rmm::percent_of_free_device_memory(50)) + : pool_res_(cuda_res_, rmm::percent_of_free_device_memory(50)), + prev_res_(rmm::mr::set_current_device_resource_ref(pool_res_)) { - rmm::mr::set_current_device_resource(&pool_res_); } - ~using_pool_memory_res() { rmm::mr::set_current_device_resource(orig_res_); } + ~using_pool_memory_res() { rmm::mr::set_current_device_resource_ref(prev_res_); } }; /** diff --git a/cpp/bench/prims/matrix/gather.cu b/cpp/bench/prims/matrix/gather.cu index 0f9fdf4c34..a182fb5788 100644 --- a/cpp/bench/prims/matrix/gather.cu +++ b/cpp/bench/prims/matrix/gather.cu @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION. 
* SPDX-License-Identifier: Apache-2.0 */ @@ -13,7 +13,7 @@ #include #include -#include +#include #include namespace raft::bench::matrix { @@ -35,18 +35,17 @@ template struct Gather : public fixture { Gather(const GatherParams& p) : params(p), - old_mr(rmm::mr::get_current_device_resource()), - pool_mr(rmm::mr::get_current_device_resource(), 2 * (1ULL << 30)), + pool_mr(rmm::mr::get_current_device_resource_ref(), 2 * (1ULL << 30)), + prev_res_(rmm::mr::set_current_device_resource_ref(pool_mr)), matrix(this->handle), map(this->handle), out(this->handle), stencil(this->handle), matrix_h(this->handle) { - rmm::mr::set_current_device_resource(&pool_mr); } - ~Gather() { rmm::mr::set_current_device_resource(old_mr); } + ~Gather() { rmm::mr::set_current_device_resource_ref(prev_res_); } void allocate_data(const ::benchmark::State& state) override { @@ -107,8 +106,8 @@ struct Gather : public fixture { private: GatherParams params; - rmm::mr::device_memory_resource* old_mr; - rmm::mr::pool_memory_resource pool_mr; + rmm::mr::pool_memory_resource pool_mr; + cuda::mr::any_resource prev_res_; raft::device_matrix matrix, out; raft::host_matrix matrix_h; raft::device_vector stencil; diff --git a/cpp/bench/prims/random/subsample.cu b/cpp/bench/prims/random/subsample.cu index 004b940f7f..21f9efacf1 100644 --- a/cpp/bench/prims/random/subsample.cu +++ b/cpp/bench/prims/random/subsample.cu @@ -50,16 +50,15 @@ template struct sample : public fixture { sample(const sample_inputs& p) : params(p), - old_mr(rmm::mr::get_current_device_resource()), - pool_mr(rmm::mr::get_current_device_resource(), 2 * GiB), + pool_mr(rmm::mr::get_current_device_resource_ref(), 2 * GiB), + prev_mr(rmm::mr::set_current_device_resource_ref(pool_mr)), in(make_device_vector(res, p.n_samples)), out(make_device_vector(res, p.n_train)) { - rmm::mr::set_current_device_resource(&pool_mr); raft::random::RngState r(123456ULL); } - ~sample() { rmm::mr::set_current_device_resource(old_mr); } + ~sample() { 
rmm::mr::set_current_device_resource_ref(prev_mr); } void run_benchmark(::benchmark::State& state) override { std::ostringstream label_stream; @@ -81,8 +80,8 @@ struct sample : public fixture { private: float GiB = 1073741824.0f; raft::device_resources res; - rmm::mr::device_memory_resource* old_mr; - rmm::mr::pool_memory_resource pool_mr; + rmm::mr::pool_memory_resource pool_mr; + cuda::mr::any_resource prev_mr; sample_inputs params; raft::device_vector out, in; }; // struct sample diff --git a/cpp/include/raft/core/device_resources.hpp b/cpp/include/raft/core/device_resources.hpp index aa697d2e18..a672a34d3b 100644 --- a/cpp/include/raft/core/device_resources.hpp +++ b/cpp/include/raft/core/device_resources.hpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2019-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2019-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -26,7 +26,7 @@ #include #include -#include +#include #include @@ -51,15 +51,6 @@ namespace raft { */ class device_resources : public resources { public: - device_resources(const device_resources& handle, - std::shared_ptr workspace_resource, - std::optional allocation_limit = std::nullopt) - : resources{handle} - { - // replace the resource factory for the workspace_resources - resource::set_workspace_resource(*this, workspace_resource, allocation_limit); - } - device_resources(const device_resources& handle) : resources{handle} {} device_resources(device_resources&&) = delete; device_resources& operator=(device_resources&&) = delete; @@ -70,15 +61,9 @@ class device_resources : public resources { * @param[in] stream_view the default stream (which has the default per-thread stream if * unspecified) * @param[in] stream_pool the stream pool used (which has default of nullptr if unspecified) - * @param[in] workspace_resource an optional resource used by some functions for allocating - * temporary workspaces. 
- * @param[in] allocation_limit the total amount of memory in bytes available to the temporary - * workspace resources. */ device_resources(rmm::cuda_stream_view stream_view = rmm::cuda_stream_per_thread, - std::shared_ptr stream_pool = {nullptr}, - std::shared_ptr workspace_resource = {nullptr}, - std::optional allocation_limit = std::nullopt) + std::shared_ptr stream_pool = {nullptr}) : resources{} { resources::add_resource_factory(std::make_shared()); @@ -86,9 +71,6 @@ class device_resources : public resources { std::make_shared(stream_view)); resources::add_resource_factory( std::make_shared(stream_pool)); - if (workspace_resource) { - resource::set_workspace_resource(*this, workspace_resource, allocation_limit); - } } /** Destroys all held-up resources */ @@ -214,7 +196,7 @@ class device_resources : public resources { return resource::get_subcomm(*this, key); } - rmm::mr::device_memory_resource* get_workspace_resource() const + rmm::mr::limiting_resource_adaptor* get_workspace_resource() const { return resource::get_workspace_resource(*this); } diff --git a/cpp/include/raft/core/device_resources_manager.hpp b/cpp/include/raft/core/device_resources_manager.hpp index 7d615cec5f..24f0a7c365 100644 --- a/cpp/include/raft/core/device_resources_manager.hpp +++ b/cpp/include/raft/core/device_resources_manager.hpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -12,6 +12,7 @@ #include #include #include +#include #include #include @@ -114,12 +115,6 @@ struct device_resources_manager { std::optional max_mem_pool_size{std::size_t{}}; // Limit on workspace memory for the returned device_resources object std::optional workspace_allocation_limit{std::nullopt}; - // Optional specification of separate workspace memory resources for each - // device. 
The integer in each pair indicates the device for this memory - // resource. - std::vector, int>> workspace_mrs{}; - - auto get_workspace_memory_resource(int device_id) {} } params_; // This struct stores the underlying resources to be shared among @@ -152,35 +147,18 @@ struct device_resources_manager { }()}, pool_mr_{[&params, this]() { auto scoped_device = device_setter{device_id_}; - auto result = - std::shared_ptr>{nullptr}; + auto result = std::optional{}; // If max_mem_pool_size is nullopt or non-zero, create a pool memory // resource if (params.max_mem_pool_size.value_or(1) != 0) { - auto* upstream = - dynamic_cast(rmm::mr::get_current_device_resource()); - if (upstream != nullptr) { - result = - std::make_shared>( - upstream, - params.init_mem_pool_size.value_or(rmm::percent_of_free_device_memory(50)), - params.max_mem_pool_size); - rmm::mr::set_current_device_resource(result.get()); - } else { - RAFT_LOG_WARN( - "Pool allocation requested, but other memory resource has already been set and " - "will not be overwritten"); - } + auto upstream = rmm::mr::get_current_device_resource_ref(); + result.emplace( + upstream, + params.init_mem_pool_size.value_or(rmm::percent_of_free_device_memory(50)), + params.max_mem_pool_size); + rmm::mr::set_current_device_resource_ref(*result); } return result; - }()}, - workspace_mr_{[&params, this]() { - auto result = std::shared_ptr{nullptr}; - auto iter = std::find_if(std::begin(params.workspace_mrs), - std::end(params.workspace_mrs), - [this](auto&& pair) { return pair.second == device_id_; }); - if (iter != std::end(params.workspace_mrs)) { result = iter->first; } - return result; }()} { } @@ -216,27 +194,14 @@ struct device_resources_manager { if (pool_count() != 0) { result = pools_[get_thread_id() % pool_count()]; } return result; } - // Return a (possibly null) shared_ptr to the pool memory resource - // created for this device by the manager - [[nodiscard]] auto get_pool_memory_resource() const { return pool_mr_; } - // 
Return the RAFT workspace allocation limit that will be used by - // `device_resources` returned from this manager - [[nodiscard]] auto get_workspace_allocation_limit() const - { - return workspace_allocation_limit_; - } - // Return a (possibly null) shared_ptr to the memory resource that will - // be used for workspace allocations by `device_resources` returned from - // this manager - [[nodiscard]] auto get_workspace_memory_resource() { return workspace_mr_; } + // Return the pool memory resource created for this device by the manager (if any) + [[nodiscard]] auto& get_pool_memory_resource() { return pool_mr_; } private: int device_id_; std::unique_ptr streams_; std::vector> pools_; - std::shared_ptr> pool_mr_; - std::shared_ptr workspace_mr_; - std::optional workspace_allocation_limit_{std::nullopt}; + std::optional pool_mr_; }; // Mutex used to lock access to shared data until after the first @@ -290,10 +255,7 @@ struct device_resources_manager { auto scoped_device = device_setter(device_id); // Build the device_resources object for this thread out of shared // components - thread_resources[device_id].emplace(component_iter->get_stream(), - component_iter->get_pool(), - component_iter->get_workspace_memory_resource(), - component_iter->get_workspace_allocation_limit()); + thread_resources[device_id].emplace(component_iter->get_stream(), component_iter->get_pool()); } return thread_resources[device_id].value(); @@ -373,27 +335,6 @@ struct device_resources_manager { } } - // Thread-safe setter for workspace memory resources - void set_workspace_memory_resource_(std::shared_ptr mr, - int device_id) - { - auto lock = get_lock(); - if (params_finalized_) { - RAFT_LOG_WARN( - "Attempted to set device_resources_manager properties after resources have already been " - "retrieved"); - } else { - auto iter = std::find_if(std::begin(params_.workspace_mrs), - std::end(params_.workspace_mrs), - [device_id](auto&& pair) { return pair.second == device_id; }); - if (iter != 
std::end(params_.workspace_mrs)) { - iter->first = mr; - } else { - params_.workspace_mrs.emplace_back(mr, device_id); - } - } - } - // Retrieve the instance of this singleton static auto& get_manager() { @@ -543,24 +484,5 @@ struct device_resources_manager { set_init_mem_pool_size(init_mem); set_max_mem_pool_size(max_mem); } - - /** - * @brief Set the workspace memory resource to be used on a specific device - * - * RAFT device_resources objects can be built with a separate memory - * resource for allocating temporary workspaces. If a (non-nullptr) memory - * resource is provided by this setter, it will be used as the - * workspace memory resource for all `device_resources` returned for the - * indicated device. - * - * If called after the first call to - * `raft::device_resources_manager::get_device_resources`, no change will be made, - * and a warning will be emitted. - */ - static void set_workspace_memory_resource(std::shared_ptr mr, - int device_id = device_setter::get_current_device()) - { - get_manager().set_workspace_memory_resource_(mr, device_id); - } }; } // namespace raft diff --git a/cpp/include/raft/core/device_resources_snmg.hpp b/cpp/include/raft/core/device_resources_snmg.hpp index b591561724..01c2ab51d5 100644 --- a/cpp/include/raft/core/device_resources_snmg.hpp +++ b/cpp/include/raft/core/device_resources_snmg.hpp @@ -10,7 +10,6 @@ #include #include -#include #include #include @@ -105,10 +104,9 @@ class device_resources_snmg : public device_resources { int device_id = raft::resource::get_device_id(dev_res); pool_device_ids_.push_back(device_id); - per_device_pools_.push_back( - std::make_unique>( - rmm::mr::get_current_device_resource_ref(), - rmm::percent_of_free_device_memory(percent_of_free_memory))); + per_device_pools_.push_back(std::make_unique( + rmm::mr::get_current_device_resource_ref(), + rmm::percent_of_free_device_memory(percent_of_free_memory))); rmm::mr::set_per_device_resource_ref(rmm::cuda_device_id{device_id}, 
*per_device_pools_.back()); } @@ -151,8 +149,7 @@ class device_resources_snmg : public device_resources { } } int main_gpu_id_; - std::vector>> - per_device_pools_; + std::vector> per_device_pools_; std::vector pool_device_ids_; }; // class device_resources_snmg diff --git a/cpp/include/raft/core/handle.hpp b/cpp/include/raft/core/handle.hpp index 08a4b101ed..85486ab101 100644 --- a/cpp/include/raft/core/handle.hpp +++ b/cpp/include/raft/core/handle.hpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2019-2023, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2019-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -21,12 +21,6 @@ namespace raft { */ class handle_t : public raft::device_resources { public: - handle_t(const handle_t& handle, - std::shared_ptr workspace_resource) - : device_resources(handle, workspace_resource) - { - } - handle_t(const handle_t& handle) : device_resources{handle} {} handle_t(handle_t&&) = delete; @@ -38,13 +32,10 @@ class handle_t : public raft::device_resources { * @param[in] stream_view the default stream (which has the default per-thread stream if * unspecified) * @param[in] stream_pool the stream pool used (which has default of nullptr if unspecified) - * @param[in] workspace_resource an optional resource used by some functions for allocating - * temporary workspaces. 
*/ handle_t(rmm::cuda_stream_view stream_view = rmm::cuda_stream_per_thread, - std::shared_ptr stream_pool = {nullptr}, - std::shared_ptr workspace_resource = {nullptr}) - : device_resources{stream_view, stream_pool, workspace_resource} + std::shared_ptr stream_pool = {nullptr}) + : device_resources{stream_view, stream_pool} { } diff --git a/cpp/include/raft/core/resource/device_memory_resource.hpp b/cpp/include/raft/core/resource/device_memory_resource.hpp index 6ce2387603..17c929aae9 100644 --- a/cpp/include/raft/core/resource/device_memory_resource.hpp +++ b/cpp/include/raft/core/resource/device_memory_resource.hpp @@ -11,7 +11,6 @@ #include #include -#include #include #include #include @@ -30,91 +29,21 @@ namespace raft::resource { class device_memory_resource : public resource { public: - explicit device_memory_resource(std::shared_ptr mr) : mr_(mr) - { - if (auto* b = dynamic_cast(mr_.get())) { - any_mr_ = b->upstream(); - mr_.reset(); - } - } explicit device_memory_resource(raft::mr::device_resource ar) : any_mr_(std::move(ar)) {} ~device_memory_resource() override = default; - auto get_resource() -> void* override - { - if (mr_) return mr_.get(); - if (!bridge_) bridge_ = std::make_unique(*any_mr_); - return bridge_.get(); - } - - /** - * @brief Construct a device_async_resource_ref from a device_memory_resource pointer, - * unwrapping an any_resource_bridge if present. - * - * If the pointer is a bridge created by this class, the returned ref points - * directly to the enclosed any_resource, avoiding an extra virtual dispatch layer. 
- */ - static auto make_ref(rmm::mr::device_memory_resource* dmr) -> rmm::device_async_resource_ref - { - if (auto* b = dynamic_cast(dmr)) { - return rmm::device_async_resource_ref{b->upstream()}; - } - return rmm::device_async_resource_ref{*dmr}; - } + auto get_resource() -> void* override { return &any_mr_; } private: - class any_resource_bridge : public rmm::mr::device_memory_resource { - public: - explicit any_resource_bridge(cuda::mr::any_resource& upstream) - : upstream_(upstream) - { - } - - auto upstream() noexcept -> cuda::mr::any_resource& - { - return upstream_; - } - - protected: - void* do_allocate(std::size_t bytes, rmm::cuda_stream_view stream) override - { - return upstream_.allocate(stream, bytes); - } - - void do_deallocate(void* ptr, std::size_t bytes, rmm::cuda_stream_view stream) noexcept override - { - upstream_.deallocate(stream, ptr, bytes); - } - - [[nodiscard]] bool do_is_equal( - rmm::mr::device_memory_resource const& other) const noexcept override - { - auto const* o = dynamic_cast(&other); - return o != nullptr && upstream_ == o->upstream_; - } - - private: - cuda::mr::any_resource& upstream_; - }; - - std::shared_ptr mr_; - std::optional any_mr_; - mutable std::unique_ptr bridge_; + raft::mr::device_resource any_mr_; }; class limiting_memory_resource : public resource { public: - limiting_memory_resource(std::shared_ptr mr, - std::size_t allocation_limit, - std::optional alignment) - : upstream_(mr), mr_(make_adaptor(mr.get(), allocation_limit, alignment)) - { - } - limiting_memory_resource(raft::mr::device_resource ar, std::size_t allocation_limit, std::optional alignment) : any_upstream_(std::move(ar)), - mr_(make_adaptor(rmm::device_async_resource_ref{*any_upstream_}, allocation_limit, alignment)) + mr_(make_adaptor(rmm::device_async_resource_ref{any_upstream_}, allocation_limit, alignment)) { } @@ -123,14 +52,13 @@ class limiting_memory_resource : public resource { ~limiting_memory_resource() override = default; private: - 
std::shared_ptr upstream_; - std::optional any_upstream_; - rmm::mr::limiting_resource_adaptor mr_; + raft::mr::device_resource any_upstream_; + rmm::mr::limiting_resource_adaptor mr_; - static inline auto make_adaptor(rmm::mr::device_memory_resource* upstream, + static inline auto make_adaptor(rmm::device_async_resource_ref upstream, std::size_t limit, std::optional alignment) - -> rmm::mr::limiting_resource_adaptor + -> rmm::mr::limiting_resource_adaptor { if (alignment.has_value()) { return rmm::mr::limiting_resource_adaptor(upstream, limit, alignment.value()); @@ -138,19 +66,6 @@ class limiting_memory_resource : public resource { return rmm::mr::limiting_resource_adaptor(upstream, limit); } } - - static inline auto make_adaptor(rmm::device_async_resource_ref upstream, - std::size_t limit, - std::optional alignment) - -> rmm::mr::limiting_resource_adaptor - { - if (alignment.has_value()) { - return rmm::mr::limiting_resource_adaptor( - upstream, limit, alignment.value()); - } else { - return rmm::mr::limiting_resource_adaptor(upstream, limit); - } - } }; /** @@ -159,11 +74,8 @@ class limiting_memory_resource : public resource { */ class large_workspace_resource_factory : public resource_factory { public: - explicit large_workspace_resource_factory( - std::shared_ptr mr = {nullptr}) - : mr_{mr ? 
mr - : std::shared_ptr{ - rmm::mr::get_current_device_resource(), void_op{}}} + large_workspace_resource_factory() + : any_mr_(raft::mr::device_resource{rmm::mr::get_current_device_resource_ref()}) { } @@ -175,15 +87,10 @@ class large_workspace_resource_factory : public resource_factory { { return resource_type::LARGE_WORKSPACE_RESOURCE; } - auto make_resource() -> resource* override - { - if (any_mr_) return new device_memory_resource(*any_mr_); - return new device_memory_resource(mr_); - } + auto make_resource() -> resource* override { return new device_memory_resource(any_mr_); } private: - std::shared_ptr mr_; - std::optional any_mr_; + raft::mr::device_resource any_mr_; }; /** @@ -192,19 +99,9 @@ class large_workspace_resource_factory : public resource_factory { */ class workspace_resource_factory : public resource_factory { public: - explicit workspace_resource_factory( - std::shared_ptr mr = {nullptr}, - std::optional allocation_limit = std::nullopt, - std::optional alignment = std::nullopt) - // default_allocation_limit() is relatively heavy, skip it while unnecessary - : allocation_limit_(allocation_limit.has_value() ? allocation_limit.value() - : default_allocation_limit()), - alignment_(alignment), - mr_(mr ? mr : default_plain_resource()) - { - } - - explicit workspace_resource_factory(raft::mr::device_resource mr, + explicit workspace_resource_factory(raft::mr::device_resource mr = + raft::mr::device_resource{ + rmm::mr::get_current_device_resource_ref()}, std::optional allocation_limit = std::nullopt, std::optional alignment = std::nullopt) : allocation_limit_(allocation_limit.has_value() ? 
allocation_limit.value() @@ -217,13 +114,11 @@ class workspace_resource_factory : public resource_factory { auto get_resource_type() -> resource_type override { return resource_type::WORKSPACE_RESOURCE; } auto make_resource() -> resource* override { - if (any_mr_) return new limiting_memory_resource(*any_mr_, allocation_limit_, alignment_); - return new limiting_memory_resource(mr_, allocation_limit_, alignment_); + return new limiting_memory_resource(any_mr_, allocation_limit_, alignment_); } /** Construct a sensible default pool memory resource. */ - static inline auto default_pool_resource(std::size_t limit) - -> std::shared_ptr + static inline auto default_pool_resource(std::size_t limit) -> raft::mr::device_resource { // Set the default granularity to 1 GiB constexpr std::size_t kOneGb = 1024lu * 1024lu * 1024lu; @@ -240,35 +135,20 @@ class workspace_resource_factory : public resource_factory { // resource adaptor bad_alloc error than into the pool bad_alloc error. // 2) The pool doesn't grab too much memory on top of the 'limit'. auto max_size = std::min(limit + kOneGb / 2lu, limit * 3lu / 2lu); - auto upstream = rmm::mr::get_current_device_resource(); + auto upstream = rmm::mr::get_current_device_resource_ref(); RAFT_LOG_DEBUG( "Setting the workspace pool resource; memory limit = %zu, initial pool size = %zu, max pool " "size = %zu.", limit, min_size, max_size); - return std::make_shared>( - upstream, min_size, max_size); - } - - /** - * Get the global memory resource wrapped into an unmanaged shared_ptr (with no deleter). - * - * Note: the lifetime of the underlying `rmm::mr::get_current_device_resource()` is managed - * somewhere else, since it's passed by a raw pointer. Hence, this shared_ptr wrapper is not - * allowed to delete the pointer on destruction. 
- */ - static inline auto default_plain_resource() -> std::shared_ptr - { - return std::shared_ptr{rmm::mr::get_current_device_resource(), - void_op{}}; + return raft::mr::device_resource{rmm::mr::pool_memory_resource(upstream, min_size, max_size)}; } private: std::size_t allocation_limit_; std::optional alignment_; - std::shared_ptr mr_; - std::optional any_mr_; + raft::mr::device_resource any_mr_; static inline auto default_allocation_limit() -> std::size_t { @@ -284,14 +164,12 @@ class workspace_resource_factory : public resource_factory { namespace detail { -inline auto get_workspace_adaptor(resources const& res) - -> rmm::mr::limiting_resource_adaptor* +inline auto get_workspace_adaptor(resources const& res) -> rmm::mr::limiting_resource_adaptor* { if (!res.has_resource_factory(resource_type::WORKSPACE_RESOURCE)) { res.add_resource_factory(std::make_shared()); } - return res.get_resource>( - resource_type::WORKSPACE_RESOURCE); + return res.get_resource(resource_type::WORKSPACE_RESOURCE); } } // namespace detail @@ -306,8 +184,7 @@ inline auto get_workspace_adaptor(resources const& res) * @param res raft resources object for managing resources * @return pointer to the workspace limiting_resource_adaptor */ -inline auto get_workspace_resource(resources const& res) - -> rmm::mr::limiting_resource_adaptor* +inline auto get_workspace_resource(resources const& res) -> rmm::mr::limiting_resource_adaptor* { return detail::get_workspace_adaptor(res); } @@ -375,28 +252,9 @@ inline void set_workspace_resource(resources const& res, std::make_shared(std::move(mr), allocation_limit, alignment)); } -/** - * @deprecated Use the overload taking raft::mr::device_resource instead. - * - * @param res raft resources object for managing resources - * @param mr an RMM device_memory_resource - * @param allocation_limit - * the total amount of memory in bytes available to the temporary workspace resources. 
- * @param alignment optional alignment requirements passed to allocations - */ -[[deprecated("use the overload taking raft::mr::device_resource instead")]] -inline void set_workspace_resource(resources const& res, - std::shared_ptr mr, - std::optional allocation_limit = std::nullopt, - std::optional alignment = std::nullopt) -{ - res.add_resource_factory( - std::make_shared(mr, allocation_limit, alignment)); -}; - /** * Set the temporary workspace resource to a pool on top of the global memory resource - * (`rmm::mr::get_current_device_resource()`. + * (`rmm::mr::get_current_device_resource_ref()`). * * @param res raft resources object for managing resources * @param allocation_limit @@ -416,7 +274,7 @@ inline void set_workspace_to_pool_resource( /** * Set the temporary workspace resource the same as the global memory resource - * (`rmm::mr::get_current_device_resource()`. + * (`rmm::mr::get_current_device_resource_ref()`). * * Note, the workspace resource is always limited; the limit here defines how much of the global * memory resource can be consumed by the workspace allocations. 
@@ -429,7 +287,9 @@ inline void set_workspace_to_global_resource( resources const& res, std::optional allocation_limit = std::nullopt) { res.add_resource_factory(std::make_shared( - workspace_resource_factory::default_plain_resource(), allocation_limit, std::nullopt)); + raft::mr::device_resource{rmm::mr::get_current_device_resource_ref()}, + allocation_limit, + std::nullopt)); }; /** @@ -443,8 +303,8 @@ inline auto get_large_workspace_resource_ref(resources const& res) -> rmm::devic if (!res.has_resource_factory(resource_type::LARGE_WORKSPACE_RESOURCE)) { res.add_resource_factory(std::make_shared()); } - return device_memory_resource::make_ref( - res.get_resource(resource_type::LARGE_WORKSPACE_RESOURCE)); + return rmm::device_async_resource_ref{ + *res.get_resource(resource_type::LARGE_WORKSPACE_RESOURCE)}; } /** @@ -458,34 +318,6 @@ inline void set_large_workspace_resource(resources const& res, raft::mr::device_ res.add_resource_factory(std::make_shared(std::move(mr))); } -/** - * @deprecated Use get_large_workspace_resource_ref() instead. - * - * @param res raft resources object for managing resources - * @return pointer to the large workspace device memory resource - */ -[[deprecated("use get_large_workspace_resource_ref() instead")]] -inline auto get_large_workspace_resource(resources const& res) -> rmm::mr::device_memory_resource* -{ - if (!res.has_resource_factory(resource_type::LARGE_WORKSPACE_RESOURCE)) { - res.add_resource_factory(std::make_shared()); - } - return res.get_resource(resource_type::LARGE_WORKSPACE_RESOURCE); -}; - -/** - * @deprecated Use the overload taking raft::mr::device_resource instead. 
- * - * @param res raft resources object for managing resources - * @param mr an RMM device_memory_resource - */ -[[deprecated("use the overload taking raft::mr::device_resource instead")]] -inline void set_large_workspace_resource(resources const& res, - std::shared_ptr mr) -{ - res.add_resource_factory(std::make_shared(mr)); -}; - /** @} */ } // namespace raft::resource diff --git a/cpp/include/raft/matrix/detail/select_k-inl.cuh b/cpp/include/raft/matrix/detail/select_k-inl.cuh index aacde43135..37411ba0bd 100644 --- a/cpp/include/raft/matrix/detail/select_k-inl.cuh +++ b/cpp/include/raft/matrix/detail/select_k-inl.cuh @@ -84,7 +84,7 @@ void segmented_sort_by_key(raft::resources const& handle, bool asc) { auto stream = resource::get_cuda_stream(handle); - auto mr = resource::get_workspace_resource(handle); + auto mr = resource::get_workspace_resource_ref(handle); auto out_inds = raft::make_device_mdarray(handle, mr, raft::make_extents(n_elements)); auto out_dists = @@ -275,7 +275,7 @@ void select_k(raft::resources const& handle, } if (sorted) { auto offsets = make_device_mdarray( - handle, resource::get_workspace_resource(handle), make_extents(batch_size + 1)); + handle, resource::get_workspace_resource_ref(handle), make_extents(batch_size + 1)); raft::linalg::map_offset(handle, offsets.view(), mul_const_op(k)); auto keys = raft::make_device_vector_view(out_val, (IdxT)(batch_size * k)); diff --git a/cpp/include/raft/matrix/detail/select_radix.cuh b/cpp/include/raft/matrix/detail/select_radix.cuh index 2cdd2465c7..a6dd7e0ce5 100644 --- a/cpp/include/raft/matrix/detail/select_radix.cuh +++ b/cpp/include/raft/matrix/detail/select_radix.cuh @@ -1267,7 +1267,7 @@ void select_k(raft::resources const& res, "CSR layout requires a non-null indptr array (len_i)!"); auto stream = resource::get_cuda_stream(res); - auto mr = resource::get_workspace_resource(res); + auto mr = resource::get_workspace_resource_ref(res); if (k == len && RowLayout::is_uniform) { RAFT_CUDA_TRY( 
cudaMemcpyAsync(out, in, sizeof(T) * batch_size * len, cudaMemcpyDeviceToDevice, stream)); diff --git a/cpp/include/raft/matrix/detail/select_warpsort.cuh b/cpp/include/raft/matrix/detail/select_warpsort.cuh index a7b2503b24..a480743664 100644 --- a/cpp/include/raft/matrix/detail/select_warpsort.cuh +++ b/cpp/include/raft/matrix/detail/select_warpsort.cuh @@ -1133,7 +1133,7 @@ void select_k_impl(raft::resources const& res, out_idx, select_min, resource::get_cuda_stream(res), - resource::get_workspace_resource(res)); + resource::get_workspace_resource_ref(res)); } /** @@ -1210,7 +1210,7 @@ void select_k(raft::resources const& res, out_idx, select_min, resource::get_cuda_stream(res), - resource::get_workspace_resource(res)); + resource::get_workspace_resource_ref(res)); } else { calc_launch_parameter( res, batch_size, len, k, &num_of_block, &num_of_warp); @@ -1226,7 +1226,7 @@ void select_k(raft::resources const& res, out_idx, select_min, resource::get_cuda_stream(res), - resource::get_workspace_resource(res)); + resource::get_workspace_resource_ref(res)); } } diff --git a/cpp/include/raft/sparse/convert/detail/bitmap_to_csr.cuh b/cpp/include/raft/sparse/convert/detail/bitmap_to_csr.cuh index 25217452d5..42266cee71 100644 --- a/cpp/include/raft/sparse/convert/detail/bitmap_to_csr.cuh +++ b/cpp/include/raft/sparse/convert/detail/bitmap_to_csr.cuh @@ -296,7 +296,7 @@ void bitmap_to_csr(raft::resources const& handle, sub_nnz_size, bits_per_sub_col); - rmm::device_async_resource_ref device_memory = resource::get_workspace_resource(handle); + rmm::device_async_resource_ref device_memory = resource::get_workspace_resource_ref(handle); rmm::device_uvector sub_nnz(sub_nnz_size + 1, stream, device_memory); calc_nnz_by_rows(handle, diff --git a/cpp/include/raft/sparse/convert/detail/bitset_to_csr.cuh b/cpp/include/raft/sparse/convert/detail/bitset_to_csr.cuh index c0aadc4243..efb418434e 100644 --- a/cpp/include/raft/sparse/convert/detail/bitset_to_csr.cuh +++ 
b/cpp/include/raft/sparse/convert/detail/bitset_to_csr.cuh @@ -109,7 +109,7 @@ void bitset_to_csr(raft::resources const& handle, sub_nnz_size, bits_per_sub_col); - rmm::device_async_resource_ref device_memory = resource::get_workspace_resource(handle); + rmm::device_async_resource_ref device_memory = resource::get_workspace_resource_ref(handle); rmm::device_uvector sub_nnz(sub_nnz_size + 1, stream, device_memory); calc_nnz_by_rows(handle, diff --git a/cpp/include/raft/sparse/matrix/detail/select_k-inl.cuh b/cpp/include/raft/sparse/matrix/detail/select_k-inl.cuh index e7d6d2b393..b788f5244c 100644 --- a/cpp/include/raft/sparse/matrix/detail/select_k-inl.cuh +++ b/cpp/include/raft/sparse/matrix/detail/select_k-inl.cuh @@ -131,7 +131,7 @@ void select_k(raft::resources const& handle, if (sorted) { auto offsets = make_device_mdarray( - handle, resource::get_workspace_resource(handle), make_extents(batch_size + 1)); + handle, resource::get_workspace_resource_ref(handle), make_extents(batch_size + 1)); raft::linalg::map_offset(handle, offsets.view(), mul_const_op(k)); auto keys = diff --git a/cpp/include/raft/util/memory_tracking_resources.hpp b/cpp/include/raft/util/memory_tracking_resources.hpp index e921fdd087..4891b25aa7 100644 --- a/cpp/include/raft/util/memory_tracking_resources.hpp +++ b/cpp/include/raft/util/memory_tracking_resources.hpp @@ -15,7 +15,6 @@ #include #include -#include #include #include @@ -108,9 +107,6 @@ class memory_tracking_resources : public resources { { report_.stop(); raft::mr::set_default_host_resource(old_host_ref_); - // Restore pointer map first (also overwrites ref map), then restore the - // original ref map separately, since the two may have been set independently. - rmm::mr::set_current_device_resource(old_device_mr_); rmm::mr::set_current_device_resource_ref(old_device_ref_); } @@ -131,7 +127,6 @@ class memory_tracking_resources : public resources { owned_stream_(std::move(owned_stream)), report_(out_override ? 
*out_override : *owned_stream_, sample_interval), old_host_ref_(raft::mr::get_default_host_resource()), - old_device_mr_(rmm::mr::get_current_device_resource()), old_device_ref_(rmm::mr::get_current_device_resource_ref()) { init(); @@ -146,7 +141,6 @@ class memory_tracking_resources : public resources { raft::mr::resource_monitor report_; raft::mr::host_resource_ref old_host_ref_; - rmm::mr::device_memory_resource* old_device_mr_; rmm::device_async_resource_ref old_device_ref_; std::size_t saved_ws_limit_{}; @@ -157,37 +151,7 @@ class memory_tracking_resources : public resources { using device_stats_t = raft::mr::statistics_adaptor; using device_notify_t = raft::mr::notifying_adaptor; - // Bridge: exposes device_notify_t as an rmm::mr::device_memory_resource so - // that set_current_device_resource(ptr) updates both the pointer-based and - // the ref-based global device resource maps in RMM. - class device_tracking_bridge : public rmm::mr::device_memory_resource { - device_notify_t adaptor_; - - protected: - void* do_allocate(std::size_t bytes, rmm::cuda_stream_view stream) override - { - return adaptor_.allocate(cuda::stream_ref{stream.value()}, bytes); - } - void do_deallocate(void* ptr, std::size_t bytes, rmm::cuda_stream_view stream) noexcept override - { - adaptor_.deallocate(cuda::stream_ref{stream.value()}, ptr, bytes); - } - [[nodiscard]] bool do_is_equal( - rmm::mr::device_memory_resource const& other) const noexcept override - { - return this == &other; - } - - public: - explicit device_tracking_bridge(device_notify_t adaptor) : adaptor_(std::move(adaptor)) {} - - [[nodiscard]] auto adaptor_ref() noexcept -> cuda::mr::resource_ref - { - return adaptor_; - } - }; - - std::unique_ptr device_bridge_; + std::unique_ptr device_adaptor_; void init() { @@ -230,16 +194,11 @@ class memory_tracking_resources : public resources { } // --- Device (global) --- - // Use set_current_device_resource(ptr) to update both the pointer map and the ref map, - // then overwrite 
the ref map to point directly at the adaptor (skipping the bridge). { - rmm::device_async_resource_ref dev_ref{*old_device_mr_}; - device_stats_t sa{dev_ref}; + device_stats_t sa{old_device_ref_}; report_.register_source("device", sa.get_stats()); - device_bridge_ = std::make_unique( - device_notify_t{std::move(sa), report_.get_notifier()}); - rmm::mr::set_current_device_resource(device_bridge_.get()); - rmm::mr::set_current_device_resource_ref(device_bridge_->adaptor_ref()); + device_adaptor_ = std::make_unique(std::move(sa), report_.get_notifier()); + rmm::mr::set_current_device_resource_ref(*device_adaptor_); } // --- Workspace (track upstream to preserve limiting_resource_adaptor) --- diff --git a/cpp/tests/core/device_resources_manager.cpp b/cpp/tests/core/device_resources_manager.cpp index f88ffb9635..fa833547bc 100644 --- a/cpp/tests/core/device_resources_manager.cpp +++ b/cpp/tests/core/device_resources_manager.cpp @@ -1,16 +1,12 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION. 
* SPDX-License-Identifier: Apache-2.0 */ #include #include #include -#include -#include #include -#include -#include #include @@ -39,7 +35,6 @@ TEST(DeviceResourcesManager, ObeysSetters) auto pools_per_device = 3; auto streams_per_pool = 7; auto workspace_limit = 2048; - auto workspace_init = 1024; device_resources_manager::set_streams_per_device(streams_per_device); device_resources_manager::set_stream_pools_per_device(pools_per_device, streams_per_pool); device_resources_manager::set_mem_pool(); @@ -50,31 +45,6 @@ TEST(DeviceResourcesManager, ObeysSetters) // Provide lock for counting unique objects auto mtx = std::mutex{}; - auto workspace_mrs = - std::array>, 2>{ - nullptr, nullptr}; - auto alternate_workspace_mrs = std::array, 2>{}; - auto upstream_mrs = std::array{ - dynamic_cast( - rmm::mr::get_per_device_resource(rmm::cuda_device_id{devices[0]})), - dynamic_cast( - rmm::mr::get_per_device_resource(rmm::cuda_device_id{devices[1]}))}; - - for (auto i = std::size_t{}; i < devices.size(); ++i) { - auto scoped_device = device_setter{devices[i]}; - if (upstream_mrs[i] == nullptr) { - RAFT_LOG_WARN( - "RMM memory resource already set. 
Tests for device_resources_manger will be incomplete."); - } else { - workspace_mrs[i] = - std::make_shared>( - upstream_mrs[i], workspace_init, workspace_limit); - alternate_workspace_mrs[i] = std::make_shared(); - } - } - - device_resources_manager::set_workspace_memory_resource(workspace_mrs[0], devices[0]); - device_resources_manager::set_workspace_memory_resource(workspace_mrs[1], devices[1]); // Suppress the many warnings from testing use of setters after initial // get_device_resources call @@ -103,14 +73,6 @@ TEST(DeviceResourcesManager, ObeysSetters) auto const& pool = res.get_stream_pool(); EXPECT_EQ(streams_per_pool, pool.get_pool_size()); - auto* mr = dynamic_cast*>( - rmm::mr::get_current_device_resource()); - - if (upstream_mrs[i % devices.size()] != nullptr) { - // Expect that the current memory resource is a pool memory resource as requested - EXPECT_NE(mr, nullptr); - } - { auto lock = std::unique_lock{mtx}; unique_streams[device].insert(primary_stream); @@ -121,8 +83,6 @@ TEST(DeviceResourcesManager, ObeysSetters) device_resources_manager::set_stream_pools_per_device(pools_per_device - 1); device_resources_manager::set_mem_pool(); device_resources_manager::set_workspace_allocation_limit(1024); - device_resources_manager::set_workspace_memory_resource( - alternate_workspace_mrs[i % devices.size()], devices[i % devices.size()]); } EXPECT_EQ(streams_per_device, unique_streams[devices[0]].size()); diff --git a/cpp/tests/core/handle.cpp b/cpp/tests/core/handle.cpp index 322e56da91..928490c8ab 100644 --- a/cpp/tests/core/handle.cpp +++ b/cpp/tests/core/handle.cpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2020-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2020-2026, NVIDIA CORPORATION. 
* SPDX-License-Identifier: Apache-2.0 */ @@ -12,7 +12,7 @@ #include #include -#include +#include #include #include @@ -275,38 +275,32 @@ TEST(Raft, WorkspaceResource) resource::get_workspace_resource(handle)->get_upstream_resource()}; // Let's create a pooled resource - auto pool_mr = std::shared_ptr{new rmm::mr::pool_memory_resource( - rmm::mr::get_current_device_resource(), rmm::percent_of_free_device_memory(50))}; + raft::mr::device_resource pool_mr{rmm::mr::pool_memory_resource( + rmm::mr::get_current_device_resource_ref(), rmm::percent_of_free_device_memory(50))}; // A tiny workspace of 1MB size_t max_size = 1024 * 1024; // Replace the resource resource::set_workspace_resource(handle, pool_mr, max_size); - auto new_mr = resource::get_workspace_resource(handle); - - // By this point, the orig_mr likely points to a non-existent resource; don't dereference! - ASSERT_NE(orig_mr, rmm::device_async_resource_ref{new_mr}); - ASSERT_EQ(rmm::device_async_resource_ref{pool_mr.get()}, new_mr->get_upstream_resource()); - // We can safely reset pool_mr, because the shared_ptr to the pool memory stays in the resource - pool_mr.reset(); + auto* new_mr = resource::get_workspace_resource(handle); auto stream = resource::get_cuda_stream(handle); - rmm::device_buffer buf(max_size / 2, stream, new_mr); + rmm::device_buffer buf(max_size / 2, stream, *new_mr); // Note, the underlying pool allocator likely uses more space than reported here ASSERT_EQ(max_size, resource::get_workspace_total_bytes(handle)); ASSERT_EQ(buf.size(), resource::get_workspace_used_bytes(handle)); ASSERT_EQ(max_size - buf.size(), resource::get_workspace_free_bytes(handle)); - // this should throw, becaise we partially used the space. - ASSERT_THROW((rmm::device_buffer{max_size, stream, new_mr}), rmm::bad_alloc); + // this should throw, because we partially used the space. 
+ ASSERT_THROW((rmm::device_buffer{max_size, stream, *new_mr}), rmm::bad_alloc); } TEST(Raft, WorkspaceResourceCopy) { raft::handle_t res; - auto orig_mr = resource::get_workspace_resource(res); + auto* orig_mr = resource::get_workspace_resource(res); auto orig_size = resource::get_workspace_total_bytes(res); { @@ -314,8 +308,8 @@ TEST(Raft, WorkspaceResourceCopy) raft::resources tmp_res(res); resource::set_workspace_resource( tmp_res, - std::shared_ptr{new rmm::mr::pool_memory_resource( - rmm::mr::get_current_device_resource(), rmm::percent_of_free_device_memory(50))}, + raft::mr::device_resource{rmm::mr::pool_memory_resource( + rmm::mr::get_current_device_resource_ref(), rmm::percent_of_free_device_memory(50))}, orig_size * 2); ASSERT_EQ(orig_mr, resource::get_workspace_resource(res)); diff --git a/cpp/tests/core/mdarray.cu b/cpp/tests/core/mdarray.cu index a434b7e821..5c56177571 100644 --- a/cpp/tests/core/mdarray.cu +++ b/cpp/tests/core/mdarray.cu @@ -1037,39 +1037,4 @@ void test_mdarray_unravel() TEST(MDArray, Unravel) { test_mdarray_unravel(); } -void test_device_resource_bridge_unwrap() -{ - auto stream = rmm::cuda_stream_default; - - // holder2 is created from a bridge wrapping holder1's any_resource. - // After holder1 (and its bridge) are destroyed, holder2 must still work — - // proving the bridge was unwrapped and the any_resource was copied out. 
- std::unique_ptr holder2; - { - auto any_res = raft::mr::device_resource{rmm::mr::get_current_device_resource_ref()}; - resource::device_memory_resource holder1{any_res}; - - // get_resource() lazily creates a bridge - auto* bridge_ptr = static_cast(holder1.get_resource()); - - // Wrap bridge in a non-owning shared_ptr; holder1 owns the bridge - auto shared_bridge = std::shared_ptr( - bridge_ptr, [](rmm::mr::device_memory_resource*) {}); - - // The shared_ptr constructor detects the bridge and copies the any_resource - holder2 = std::make_unique(shared_bridge); - } - // holder1 is destroyed; the bridge is freed. - // holder2 survives because it copied the any_resource, not the bridge pointer. - - auto ref = resource::device_memory_resource::make_ref( - static_cast(holder2->get_resource())); - - void* ptr = ref.allocate(stream, 1024); - ASSERT_NE(ptr, nullptr); - ref.deallocate(stream, ptr, 1024); -} - -TEST(MDArray, DeviceResourceBridgeUnwrap) { test_device_resource_bridge_unwrap(); } - } // namespace raft diff --git a/cpp/tests/random/multi_variable_gaussian.cu b/cpp/tests/random/multi_variable_gaussian.cu index 59f611995b..d58fda18b4 100644 --- a/cpp/tests/random/multi_variable_gaussian.cu +++ b/cpp/tests/random/multi_variable_gaussian.cu @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2018-2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2018-2026, NVIDIA CORPORATION. 
* SPDX-License-Identifier: Apache-2.0 */ @@ -278,7 +278,7 @@ class MVGMdspanTest : public ::testing::TestWithParam> { raft::device_matrix_view X_view(X_d.data(), dim, nPoints); raft::random::multi_variable_gaussian( - handle, rmm::mr::get_current_device_resource(), x_view, P_view, X_view, method); + handle, rmm::mr::get_current_device_resource_ref(), x_view, P_view, X_view, method); // saving the mean of the randoms in Rand_mean //@todo can be swapped with a API that calculates mean From c36d03b114bca8a5ed83598376980e119a6dce4e Mon Sep 17 00:00:00 2001 From: Bradley Dice Date: Wed, 15 Apr 2026 11:42:07 -0500 Subject: [PATCH 2/4] Inline upstream memory resource variable in benchmark MR composition --- cpp/bench/prims/common/benchmark.hpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/cpp/bench/prims/common/benchmark.hpp b/cpp/bench/prims/common/benchmark.hpp index 48f54d633e..7adbaa6476 100644 --- a/cpp/bench/prims/common/benchmark.hpp +++ b/cpp/bench/prims/common/benchmark.hpp @@ -33,19 +33,18 @@ namespace raft::bench { */ struct using_pool_memory_res { private: - rmm::mr::cuda_memory_resource cuda_res_{}; rmm::mr::pool_memory_resource pool_res_; cuda::mr::any_resource prev_res_; public: using_pool_memory_res(size_t initial_size, size_t max_size) - : pool_res_(cuda_res_, initial_size, max_size), + : pool_res_(rmm::mr::cuda_memory_resource{}, initial_size, max_size), prev_res_(rmm::mr::set_current_device_resource_ref(pool_res_)) { } using_pool_memory_res() - : pool_res_(cuda_res_, rmm::percent_of_free_device_memory(50)), + : pool_res_(rmm::mr::cuda_memory_resource{}, rmm::percent_of_free_device_memory(50)), prev_res_(rmm::mr::set_current_device_resource_ref(pool_res_)) { } From 414d6ad90337adf8884f7b353de73874785e9f62 Mon Sep 17 00:00:00 2001 From: Bradley Dice Date: Thu, 16 Apr 2026 11:00:57 -0500 Subject: [PATCH 3/4] Replace deprecated set_current_device_resource_ref with set_current_device_resource --- cpp/bench/prims/common/benchmark.hpp 
| 6 +++--- cpp/bench/prims/matrix/gather.cu | 4 ++-- cpp/bench/prims/random/subsample.cu | 4 ++-- cpp/include/raft/core/device_resources_manager.hpp | 2 +- cpp/include/raft/util/memory_tracking_resources.hpp | 4 ++-- 5 files changed, 10 insertions(+), 10 deletions(-) diff --git a/cpp/bench/prims/common/benchmark.hpp b/cpp/bench/prims/common/benchmark.hpp index 7adbaa6476..de2878ec54 100644 --- a/cpp/bench/prims/common/benchmark.hpp +++ b/cpp/bench/prims/common/benchmark.hpp @@ -39,17 +39,17 @@ struct using_pool_memory_res { public: using_pool_memory_res(size_t initial_size, size_t max_size) : pool_res_(rmm::mr::cuda_memory_resource{}, initial_size, max_size), - prev_res_(rmm::mr::set_current_device_resource_ref(pool_res_)) + prev_res_(rmm::mr::set_current_device_resource(pool_res_)) { } using_pool_memory_res() : pool_res_(rmm::mr::cuda_memory_resource{}, rmm::percent_of_free_device_memory(50)), - prev_res_(rmm::mr::set_current_device_resource_ref(pool_res_)) + prev_res_(rmm::mr::set_current_device_resource(pool_res_)) { } - ~using_pool_memory_res() { rmm::mr::set_current_device_resource_ref(prev_res_); } + ~using_pool_memory_res() { rmm::mr::set_current_device_resource(prev_res_); } }; /** diff --git a/cpp/bench/prims/matrix/gather.cu b/cpp/bench/prims/matrix/gather.cu index a182fb5788..77221bf5de 100644 --- a/cpp/bench/prims/matrix/gather.cu +++ b/cpp/bench/prims/matrix/gather.cu @@ -36,7 +36,7 @@ struct Gather : public fixture { Gather(const GatherParams& p) : params(p), pool_mr(rmm::mr::get_current_device_resource_ref(), 2 * (1ULL << 30)), - prev_res_(rmm::mr::set_current_device_resource_ref(pool_mr)), + prev_res_(rmm::mr::set_current_device_resource(pool_mr)), matrix(this->handle), map(this->handle), out(this->handle), @@ -45,7 +45,7 @@ struct Gather : public fixture { { } - ~Gather() { rmm::mr::set_current_device_resource_ref(prev_res_); } + ~Gather() { rmm::mr::set_current_device_resource(prev_res_); } void allocate_data(const ::benchmark::State& state) 
override { diff --git a/cpp/bench/prims/random/subsample.cu b/cpp/bench/prims/random/subsample.cu index 21f9efacf1..1aa0a3b827 100644 --- a/cpp/bench/prims/random/subsample.cu +++ b/cpp/bench/prims/random/subsample.cu @@ -51,14 +51,14 @@ struct sample : public fixture { sample(const sample_inputs& p) : params(p), pool_mr(rmm::mr::get_current_device_resource_ref(), 2 * GiB), - prev_mr(rmm::mr::set_current_device_resource_ref(pool_mr)), + prev_mr(rmm::mr::set_current_device_resource(pool_mr)), in(make_device_vector(res, p.n_samples)), out(make_device_vector(res, p.n_train)) { raft::random::RngState r(123456ULL); } - ~sample() { rmm::mr::set_current_device_resource_ref(prev_mr); } + ~sample() { rmm::mr::set_current_device_resource(prev_mr); } void run_benchmark(::benchmark::State& state) override { std::ostringstream label_stream; diff --git a/cpp/include/raft/core/device_resources_manager.hpp b/cpp/include/raft/core/device_resources_manager.hpp index 24f0a7c365..5a6ab59826 100644 --- a/cpp/include/raft/core/device_resources_manager.hpp +++ b/cpp/include/raft/core/device_resources_manager.hpp @@ -156,7 +156,7 @@ struct device_resources_manager { upstream, params.init_mem_pool_size.value_or(rmm::percent_of_free_device_memory(50)), params.max_mem_pool_size); - rmm::mr::set_current_device_resource_ref(*result); + rmm::mr::set_current_device_resource(*result); } return result; }()} diff --git a/cpp/include/raft/util/memory_tracking_resources.hpp b/cpp/include/raft/util/memory_tracking_resources.hpp index 4891b25aa7..306c10ce1e 100644 --- a/cpp/include/raft/util/memory_tracking_resources.hpp +++ b/cpp/include/raft/util/memory_tracking_resources.hpp @@ -107,7 +107,7 @@ class memory_tracking_resources : public resources { { report_.stop(); raft::mr::set_default_host_resource(old_host_ref_); - rmm::mr::set_current_device_resource_ref(old_device_ref_); + rmm::mr::set_current_device_resource(old_device_ref_); } memory_tracking_resources(memory_tracking_resources const&) = 
delete; @@ -198,7 +198,7 @@ class memory_tracking_resources : public resources { device_stats_t sa{old_device_ref_}; report_.register_source("device", sa.get_stats()); device_adaptor_ = std::make_unique(std::move(sa), report_.get_notifier()); - rmm::mr::set_current_device_resource_ref(*device_adaptor_); + rmm::mr::set_current_device_resource(*device_adaptor_); } // --- Workspace (track upstream to preserve limiting_resource_adaptor) --- From f05df96d0fca045d404c53a475a19dc884114994 Mon Sep 17 00:00:00 2001 From: Bradley Dice Date: Mon, 20 Apr 2026 23:01:54 -0500 Subject: [PATCH 4/4] Re-add workspace resource APIs migrated to CCCL types Restore workspace_resource constructor parameters and copy constructors on device_resources and handle_t using raft::mr::device_resource instead of shared_ptr. Re-add workspace MR plumbing in device_resources_manager (storage, setter, getter, pass-through). Add TODO for reinstating the dynamic_cast guard once CCCL exposes resource_cast. --- cpp/include/raft/core/device_resources.hpp | 19 ++++- .../raft/core/device_resources_manager.hpp | 72 ++++++++++++++++++- cpp/include/raft/core/handle.hpp | 12 +++- cpp/tests/core/device_resources_manager.cpp | 14 ++++ 4 files changed, 112 insertions(+), 5 deletions(-) diff --git a/cpp/include/raft/core/device_resources.hpp b/cpp/include/raft/core/device_resources.hpp index a672a34d3b..753ac769d3 100644 --- a/cpp/include/raft/core/device_resources.hpp +++ b/cpp/include/raft/core/device_resources.hpp @@ -51,6 +51,14 @@ namespace raft { */ class device_resources : public resources { public: + device_resources(const device_resources& handle, + raft::mr::device_resource workspace_resource, + std::optional allocation_limit = std::nullopt) + : resources{handle} + { + resource::set_workspace_resource(*this, std::move(workspace_resource), allocation_limit); + } + device_resources(const device_resources& handle) : resources{handle} {} device_resources(device_resources&&) = delete; device_resources& 
operator=(device_resources&&) = delete; @@ -61,9 +69,15 @@ class device_resources : public resources { * @param[in] stream_view the default stream (which has the default per-thread stream if * unspecified) * @param[in] stream_pool the stream pool used (which has default of nullptr if unspecified) + * @param[in] workspace_resource an optional resource used by some functions for allocating + * temporary workspaces. + * @param[in] allocation_limit the total amount of memory in bytes available to the temporary + * workspace resources. */ device_resources(rmm::cuda_stream_view stream_view = rmm::cuda_stream_per_thread, - std::shared_ptr stream_pool = {nullptr}) + std::shared_ptr stream_pool = {nullptr}, + std::optional workspace_resource = std::nullopt, + std::optional allocation_limit = std::nullopt) : resources{} { resources::add_resource_factory(std::make_shared()); @@ -71,6 +85,9 @@ class device_resources : public resources { std::make_shared(stream_view)); resources::add_resource_factory( std::make_shared(stream_pool)); + if (workspace_resource) { + resource::set_workspace_resource(*this, std::move(*workspace_resource), allocation_limit); + } } /** Destroys all held-up resources */ diff --git a/cpp/include/raft/core/device_resources_manager.hpp b/cpp/include/raft/core/device_resources_manager.hpp index 5a6ab59826..b400546281 100644 --- a/cpp/include/raft/core/device_resources_manager.hpp +++ b/cpp/include/raft/core/device_resources_manager.hpp @@ -115,6 +115,10 @@ struct device_resources_manager { std::optional max_mem_pool_size{std::size_t{}}; // Limit on workspace memory for the returned device_resources object std::optional workspace_allocation_limit{std::nullopt}; + // Optional specification of separate workspace memory resources for each + // device. The integer in each pair indicates the device for this memory + // resource. 
+ std::vector> workspace_mrs{}; } params_; // This struct stores the underlying resources to be shared among @@ -151,6 +155,9 @@ struct device_resources_manager { // If max_mem_pool_size is nullopt or non-zero, create a pool memory // resource if (params.max_mem_pool_size.value_or(1) != 0) { + // TODO: reinstate the dynamic_cast guard that + // skipped pool creation when a non-default resource was already set, + // once CCCL exposes resource_cast or an equivalent type-query. auto upstream = rmm::mr::get_current_device_resource_ref(); result.emplace( upstream, @@ -159,7 +166,16 @@ struct device_resources_manager { rmm::mr::set_current_device_resource(*result); } return result; - }()} + }()}, + workspace_mr_{[&params, this]() { + auto result = std::optional{}; + auto iter = std::find_if(std::begin(params.workspace_mrs), + std::end(params.workspace_mrs), + [this](auto&& pair) { return pair.second == device_id_; }); + if (iter != std::end(params.workspace_mrs)) { result = iter->first; } + return result; + }()}, + workspace_allocation_limit_{params.workspace_allocation_limit} { } @@ -196,12 +212,22 @@ struct device_resources_manager { } // Return the pool memory resource created for this device by the manager (if any) [[nodiscard]] auto& get_pool_memory_resource() { return pool_mr_; } + // Return the RAFT workspace allocation limit that will be used by + // `device_resources` returned from this manager + [[nodiscard]] auto get_workspace_allocation_limit() const + { + return workspace_allocation_limit_; + } + // Return the workspace memory resource for this device (if any) + [[nodiscard]] auto& get_workspace_memory_resource() { return workspace_mr_; } private: int device_id_; std::unique_ptr streams_; std::vector> pools_; std::optional pool_mr_; + std::optional workspace_mr_; + std::optional workspace_allocation_limit_{std::nullopt}; }; // Mutex used to lock access to shared data until after the first @@ -255,7 +281,10 @@ struct device_resources_manager { auto scoped_device 
= device_setter(device_id); // Build the device_resources object for this thread out of shared // components - thread_resources[device_id].emplace(component_iter->get_stream(), component_iter->get_pool()); + thread_resources[device_id].emplace(component_iter->get_stream(), + component_iter->get_pool(), + component_iter->get_workspace_memory_resource(), + component_iter->get_workspace_allocation_limit()); } return thread_resources[device_id].value(); @@ -335,6 +364,26 @@ struct device_resources_manager { } } + // Thread-safe setter for workspace memory resources + void set_workspace_memory_resource_(raft::mr::device_resource mr, int device_id) + { + auto lock = get_lock(); + if (params_finalized_) { + RAFT_LOG_WARN( + "Attempted to set device_resources_manager properties after resources have already been " + "retrieved"); + } else { + auto iter = std::find_if(std::begin(params_.workspace_mrs), + std::end(params_.workspace_mrs), + [device_id](auto&& pair) { return pair.second == device_id; }); + if (iter != std::end(params_.workspace_mrs)) { + iter->first = std::move(mr); + } else { + params_.workspace_mrs.emplace_back(std::move(mr), device_id); + } + } + } + // Retrieve the instance of this singleton static auto& get_manager() { @@ -484,5 +533,24 @@ struct device_resources_manager { set_init_mem_pool_size(init_mem); set_max_mem_pool_size(max_mem); } + + /** + * @brief Set the workspace memory resource to be used on a specific device + * + * RAFT device_resources objects can be built with a separate memory + * resource for allocating temporary workspaces. If a memory + * resource is provided by this setter, it will be used as the + * workspace memory resource for all `device_resources` returned for the + * indicated device. + * + * If called after the first call to + * `raft::device_resources_manager::get_device_resources`, no change will be made, + * and a warning will be emitted. 
+ */ + static void set_workspace_memory_resource(raft::mr::device_resource mr, + int device_id = device_setter::get_current_device()) + { + get_manager().set_workspace_memory_resource_(std::move(mr), device_id); + } }; } // namespace raft diff --git a/cpp/include/raft/core/handle.hpp b/cpp/include/raft/core/handle.hpp index 85486ab101..ac2b5705b6 100644 --- a/cpp/include/raft/core/handle.hpp +++ b/cpp/include/raft/core/handle.hpp @@ -21,6 +21,11 @@ namespace raft { */ class handle_t : public raft::device_resources { public: + handle_t(const handle_t& handle, raft::mr::device_resource workspace_resource) + : device_resources(handle, std::move(workspace_resource)) + { + } + handle_t(const handle_t& handle) : device_resources{handle} {} handle_t(handle_t&&) = delete; @@ -32,10 +37,13 @@ class handle_t : public raft::device_resources { * @param[in] stream_view the default stream (which has the default per-thread stream if * unspecified) * @param[in] stream_pool the stream pool used (which has default of nullptr if unspecified) + * @param[in] workspace_resource an optional resource used by some functions for allocating + * temporary workspaces. 
*/ handle_t(rmm::cuda_stream_view stream_view = rmm::cuda_stream_per_thread, - std::shared_ptr stream_pool = {nullptr}) - : device_resources{stream_view, stream_pool} + std::shared_ptr stream_pool = {nullptr}, + std::optional workspace_resource = std::nullopt) + : device_resources{stream_view, stream_pool, std::move(workspace_resource)} { } diff --git a/cpp/tests/core/device_resources_manager.cpp b/cpp/tests/core/device_resources_manager.cpp index fa833547bc..b38a0265ff 100644 --- a/cpp/tests/core/device_resources_manager.cpp +++ b/cpp/tests/core/device_resources_manager.cpp @@ -7,6 +7,7 @@ #include #include +#include #include @@ -35,6 +36,7 @@ TEST(DeviceResourcesManager, ObeysSetters) auto pools_per_device = 3; auto streams_per_pool = 7; auto workspace_limit = 2048; + auto workspace_init = 1024; device_resources_manager::set_streams_per_device(streams_per_device); device_resources_manager::set_stream_pools_per_device(pools_per_device, streams_per_pool); device_resources_manager::set_mem_pool(); @@ -46,6 +48,15 @@ TEST(DeviceResourcesManager, ObeysSetters) // Provide lock for counting unique objects auto mtx = std::mutex{}; + for (auto i = std::size_t{}; i < devices.size(); ++i) { + auto scoped_device = device_setter{devices[i]}; + auto upstream = rmm::mr::get_current_device_resource_ref(); + device_resources_manager::set_workspace_memory_resource( + raft::mr::device_resource{ + rmm::mr::pool_memory_resource(upstream, workspace_init, workspace_limit)}, + devices[i]); + } + // Suppress the many warnings from testing use of setters after initial // get_device_resources call auto scoped_log_level = @@ -83,6 +94,9 @@ TEST(DeviceResourcesManager, ObeysSetters) device_resources_manager::set_stream_pools_per_device(pools_per_device - 1); device_resources_manager::set_mem_pool(); device_resources_manager::set_workspace_allocation_limit(1024); + device_resources_manager::set_workspace_memory_resource( + raft::mr::device_resource{rmm::mr::get_current_device_resource_ref()}, 
+ devices[i % devices.size()]); } EXPECT_EQ(streams_per_device, unique_streams[devices[0]].size());