Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 8 additions & 11 deletions cpp/bench/prims/common/benchmark.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
#include <rmm/cuda_stream.hpp>
#include <rmm/cuda_stream_view.hpp>
#include <rmm/device_buffer.hpp>
#include <rmm/mr/device_memory_resource.hpp>
#include <rmm/mr/cuda_memory_resource.hpp>
#include <rmm/mr/per_device_resource.hpp>
#include <rmm/mr/pool_memory_resource.hpp>

Expand All @@ -33,26 +33,23 @@ namespace raft::bench {
*/
struct using_pool_memory_res {
private:
rmm::mr::device_memory_resource* orig_res_;
rmm::mr::cuda_memory_resource cuda_res_{};
rmm::mr::pool_memory_resource<rmm::mr::device_memory_resource> pool_res_;
rmm::mr::pool_memory_resource pool_res_;
cuda::mr::any_resource<cuda::mr::device_accessible> prev_res_;

public:
using_pool_memory_res(size_t initial_size, size_t max_size)
: orig_res_(rmm::mr::get_current_device_resource()),
pool_res_(&cuda_res_, initial_size, max_size)
: pool_res_(rmm::mr::cuda_memory_resource{}, initial_size, max_size),
prev_res_(rmm::mr::set_current_device_resource(pool_res_))
{
rmm::mr::set_current_device_resource(&pool_res_);
}

using_pool_memory_res()
: orig_res_(rmm::mr::get_current_device_resource()),
pool_res_(&cuda_res_, rmm::percent_of_free_device_memory(50))
: pool_res_(rmm::mr::cuda_memory_resource{}, rmm::percent_of_free_device_memory(50)),
prev_res_(rmm::mr::set_current_device_resource(pool_res_))
{
rmm::mr::set_current_device_resource(&pool_res_);
}

~using_pool_memory_res() { rmm::mr::set_current_device_resource(orig_res_); }
~using_pool_memory_res() { rmm::mr::set_current_device_resource(prev_res_); }
};

/**
Expand Down
15 changes: 7 additions & 8 deletions cpp/bench/prims/matrix/gather.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION.
* SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/

Expand All @@ -13,7 +13,7 @@
#include <raft/util/itertools.hpp>

#include <rmm/device_uvector.hpp>
#include <rmm/mr/device_memory_resource.hpp>
#include <rmm/mr/per_device_resource.hpp>
#include <rmm/mr/pool_memory_resource.hpp>

namespace raft::bench::matrix {
Expand All @@ -35,18 +35,17 @@ template <typename T, typename MapT, typename IdxT, bool Conditional = false>
struct Gather : public fixture {
Gather(const GatherParams<IdxT>& p)
: params(p),
old_mr(rmm::mr::get_current_device_resource()),
pool_mr(rmm::mr::get_current_device_resource(), 2 * (1ULL << 30)),
pool_mr(rmm::mr::get_current_device_resource_ref(), 2 * (1ULL << 30)),
prev_res_(rmm::mr::set_current_device_resource(pool_mr)),
matrix(this->handle),
map(this->handle),
out(this->handle),
stencil(this->handle),
matrix_h(this->handle)
{
rmm::mr::set_current_device_resource(&pool_mr);
}

~Gather() { rmm::mr::set_current_device_resource(old_mr); }
~Gather() { rmm::mr::set_current_device_resource(prev_res_); }

void allocate_data(const ::benchmark::State& state) override
{
Expand Down Expand Up @@ -107,8 +106,8 @@ struct Gather : public fixture {

private:
GatherParams<IdxT> params;
rmm::mr::device_memory_resource* old_mr;
rmm::mr::pool_memory_resource<rmm::mr::device_memory_resource> pool_mr;
rmm::mr::pool_memory_resource pool_mr;
cuda::mr::any_resource<cuda::mr::device_accessible> prev_res_;
raft::device_matrix<T, IdxT> matrix, out;
raft::host_matrix<T, IdxT> matrix_h;
raft::device_vector<T, IdxT> stencil;
Expand Down
11 changes: 5 additions & 6 deletions cpp/bench/prims/random/subsample.cu
Original file line number Diff line number Diff line change
Expand Up @@ -50,16 +50,15 @@ template <typename T>
struct sample : public fixture {
sample(const sample_inputs& p)
: params(p),
old_mr(rmm::mr::get_current_device_resource()),
pool_mr(rmm::mr::get_current_device_resource(), 2 * GiB),
pool_mr(rmm::mr::get_current_device_resource_ref(), 2 * GiB),
prev_mr(rmm::mr::set_current_device_resource(pool_mr)),
in(make_device_vector<T, int64_t>(res, p.n_samples)),
out(make_device_vector<T, int64_t>(res, p.n_train))
{
rmm::mr::set_current_device_resource(&pool_mr);
raft::random::RngState r(123456ULL);
}

~sample() { rmm::mr::set_current_device_resource(old_mr); }
~sample() { rmm::mr::set_current_device_resource(prev_mr); }
void run_benchmark(::benchmark::State& state) override
{
std::ostringstream label_stream;
Expand All @@ -81,8 +80,8 @@ struct sample : public fixture {
private:
float GiB = 1073741824.0f;
raft::device_resources res;
rmm::mr::device_memory_resource* old_mr;
rmm::mr::pool_memory_resource<rmm::mr::device_memory_resource> pool_mr;
rmm::mr::pool_memory_resource pool_mr;
cuda::mr::any_resource<cuda::mr::device_accessible> prev_mr;
sample_inputs params;
raft::device_vector<T, int64_t> out, in;
}; // struct sample
Expand Down
17 changes: 8 additions & 9 deletions cpp/include/raft/core/device_resources.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2019-2025, NVIDIA CORPORATION.
* SPDX-FileCopyrightText: Copyright (c) 2019-2026, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/

Expand All @@ -26,7 +26,7 @@

#include <rmm/cuda_stream_pool.hpp>
#include <rmm/exec_policy.hpp>
#include <rmm/mr/device_memory_resource.hpp>
#include <rmm/resource_ref.hpp>

#include <cuda_runtime.h>

Expand All @@ -52,12 +52,11 @@ namespace raft {
class device_resources : public resources {
public:
device_resources(const device_resources& handle,
std::shared_ptr<rmm::mr::device_memory_resource> workspace_resource,
raft::mr::device_resource workspace_resource,
std::optional<std::size_t> allocation_limit = std::nullopt)
: resources{handle}
{
// replace the resource factory for the workspace_resources
resource::set_workspace_resource(*this, workspace_resource, allocation_limit);
resource::set_workspace_resource(*this, std::move(workspace_resource), allocation_limit);
}

device_resources(const device_resources& handle) : resources{handle} {}
Expand All @@ -77,8 +76,8 @@ class device_resources : public resources {
*/
device_resources(rmm::cuda_stream_view stream_view = rmm::cuda_stream_per_thread,
std::shared_ptr<rmm::cuda_stream_pool> stream_pool = {nullptr},
std::shared_ptr<rmm::mr::device_memory_resource> workspace_resource = {nullptr},
Comment thread
bdice marked this conversation as resolved.
std::optional<std::size_t> allocation_limit = std::nullopt)
std::optional<raft::mr::device_resource> workspace_resource = std::nullopt,
std::optional<std::size_t> allocation_limit = std::nullopt)
: resources{}
{
resources::add_resource_factory(std::make_shared<resource::device_id_resource_factory>());
Expand All @@ -87,7 +86,7 @@ class device_resources : public resources {
resources::add_resource_factory(
std::make_shared<resource::cuda_stream_pool_resource_factory>(stream_pool));
if (workspace_resource) {
resource::set_workspace_resource(*this, workspace_resource, allocation_limit);
resource::set_workspace_resource(*this, std::move(*workspace_resource), allocation_limit);
}
}

Expand Down Expand Up @@ -214,7 +213,7 @@ class device_resources : public resources {
return resource::get_subcomm(*this, key);
}

rmm::mr::device_memory_resource* get_workspace_resource() const
rmm::mr::limiting_resource_adaptor* get_workspace_resource() const
{
return resource::get_workspace_resource(*this);
}
Expand Down
66 changes: 28 additions & 38 deletions cpp/include/raft/core/device_resources_manager.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION.
* SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/

Expand All @@ -12,6 +12,7 @@
#include <rmm/cuda_stream_pool.hpp>
#include <rmm/mr/cuda_memory_resource.hpp>
#include <rmm/mr/per_device_resource.hpp>
#include <rmm/mr/pool_memory_resource.hpp>

#include <algorithm>
#include <memory>
Expand Down Expand Up @@ -117,9 +118,7 @@ struct device_resources_manager {
// Optional specification of separate workspace memory resources for each
// device. The integer in each pair indicates the device for this memory
// resource.
std::vector<std::pair<std::shared_ptr<rmm::mr::device_memory_resource>, int>> workspace_mrs{};

auto get_workspace_memory_resource(int device_id) {}
std::vector<std::pair<raft::mr::device_resource, int>> workspace_mrs{};
} params_;

// This struct stores the underlying resources to be shared among
Expand Down Expand Up @@ -152,36 +151,31 @@ struct device_resources_manager {
}()},
pool_mr_{[&params, this]() {
auto scoped_device = device_setter{device_id_};
auto result =
std::shared_ptr<rmm::mr::pool_memory_resource<rmm::mr::cuda_memory_resource>>{nullptr};
auto result = std::optional<rmm::mr::pool_memory_resource>{};
// If max_mem_pool_size is nullopt or non-zero, create a pool memory
// resource
if (params.max_mem_pool_size.value_or(1) != 0) {
auto* upstream =
dynamic_cast<rmm::mr::cuda_memory_resource*>(rmm::mr::get_current_device_resource());
if (upstream != nullptr) {
result =
std::make_shared<rmm::mr::pool_memory_resource<rmm::mr::cuda_memory_resource>>(
upstream,
params.init_mem_pool_size.value_or(rmm::percent_of_free_device_memory(50)),
params.max_mem_pool_size);
rmm::mr::set_current_device_resource(result.get());
} else {
RAFT_LOG_WARN(
"Pool allocation requested, but other memory resource has already been set and "
"will not be overwritten");
}
// TODO: reinstate the dynamic_cast<cuda_memory_resource*> guard that
// skipped pool creation when a non-default resource was already set,
// once CCCL exposes resource_cast or an equivalent type-query.
Comment on lines +158 to +160
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have marked this as a follow-up task for me to complete. cuOpt has a similar need. CCCL recently added the resource_cast feature to help with this, but I haven't yet had a chance to update RAPIDS to a CCCL version that includes it. It should be possible to fix this within 1-2 weeks.

auto upstream = rmm::mr::get_current_device_resource_ref();
result.emplace(
upstream,
params.init_mem_pool_size.value_or(rmm::percent_of_free_device_memory(50)),
params.max_mem_pool_size);
rmm::mr::set_current_device_resource(*result);
}
return result;
}()},
workspace_mr_{[&params, this]() {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Re-add

auto result = std::shared_ptr<rmm::mr::device_memory_resource>{nullptr};
auto result = std::optional<raft::mr::device_resource>{};
auto iter = std::find_if(std::begin(params.workspace_mrs),
std::end(params.workspace_mrs),
[this](auto&& pair) { return pair.second == device_id_; });
if (iter != std::end(params.workspace_mrs)) { result = iter->first; }
return result;
}()}
}()},
workspace_allocation_limit_{params.workspace_allocation_limit}
{
}

Expand Down Expand Up @@ -216,26 +210,23 @@ struct device_resources_manager {
if (pool_count() != 0) { result = pools_[get_thread_id() % pool_count()]; }
return result;
}
// Return a (possibly null) shared_ptr to the pool memory resource
// created for this device by the manager
[[nodiscard]] auto get_pool_memory_resource() const { return pool_mr_; }
// Return the pool memory resource created for this device by the manager (if any)
[[nodiscard]] auto& get_pool_memory_resource() { return pool_mr_; }
// Return the RAFT workspace allocation limit that will be used by
// `device_resources` returned from this manager
[[nodiscard]] auto get_workspace_allocation_limit() const
{
return workspace_allocation_limit_;
}
// Return a (possibly null) shared_ptr to the memory resource that will
// be used for workspace allocations by `device_resources` returned from
// this manager
[[nodiscard]] auto get_workspace_memory_resource() { return workspace_mr_; }
// Return the workspace memory resource for this device (if any)
[[nodiscard]] auto& get_workspace_memory_resource() { return workspace_mr_; }

private:
int device_id_;
std::unique_ptr<rmm::cuda_stream_pool> streams_;
std::vector<std::shared_ptr<rmm::cuda_stream_pool>> pools_;
std::shared_ptr<rmm::mr::pool_memory_resource<rmm::mr::cuda_memory_resource>> pool_mr_;
std::shared_ptr<rmm::mr::device_memory_resource> workspace_mr_;
std::optional<rmm::mr::pool_memory_resource> pool_mr_;
std::optional<raft::mr::device_resource> workspace_mr_;
std::optional<std::size_t> workspace_allocation_limit_{std::nullopt};
};

Expand Down Expand Up @@ -374,8 +365,7 @@ struct device_resources_manager {
}

// Thread-safe setter for workspace memory resources
void set_workspace_memory_resource_(std::shared_ptr<rmm::mr::device_memory_resource> mr,
int device_id)
void set_workspace_memory_resource_(raft::mr::device_resource mr, int device_id)
{
auto lock = get_lock();
if (params_finalized_) {
Expand All @@ -387,9 +377,9 @@ struct device_resources_manager {
std::end(params_.workspace_mrs),
[device_id](auto&& pair) { return pair.second == device_id; });
if (iter != std::end(params_.workspace_mrs)) {
iter->first = mr;
iter->first = std::move(mr);
} else {
params_.workspace_mrs.emplace_back(mr, device_id);
params_.workspace_mrs.emplace_back(std::move(mr), device_id);
}
}
}
Expand Down Expand Up @@ -548,7 +538,7 @@ struct device_resources_manager {
* @brief Set the workspace memory resource to be used on a specific device
*
* RAFT device_resources objects can be built with a separate memory
* resource for allocating temporary workspaces. If a (non-nullptr) memory
* resource for allocating temporary workspaces. If a memory
* resource is provided by this setter, it will be used as the
* workspace memory resource for all `device_resources` returned for the
* indicated device.
Expand All @@ -557,10 +547,10 @@ struct device_resources_manager {
* `raft::device_resources_manager::get_device_resources`, no change will be made,
* and a warning will be emitted.
*/
static void set_workspace_memory_resource(std::shared_ptr<rmm::mr::device_memory_resource> mr,
static void set_workspace_memory_resource(raft::mr::device_resource mr,
int device_id = device_setter::get_current_device())
{
get_manager().set_workspace_memory_resource_(mr, device_id);
get_manager().set_workspace_memory_resource_(std::move(mr), device_id);
}
};
} // namespace raft
11 changes: 4 additions & 7 deletions cpp/include/raft/core/device_resources_snmg.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
#include <raft/core/resource/resource_types.hpp>

#include <rmm/cuda_device.hpp>
#include <rmm/mr/device_memory_resource.hpp>
#include <rmm/mr/per_device_resource.hpp>
#include <rmm/mr/pool_memory_resource.hpp>

Expand Down Expand Up @@ -105,10 +104,9 @@ class device_resources_snmg : public device_resources {
int device_id = raft::resource::get_device_id(dev_res);
pool_device_ids_.push_back(device_id);

per_device_pools_.push_back(
std::make_unique<rmm::mr::pool_memory_resource<rmm::mr::device_memory_resource>>(
rmm::mr::get_current_device_resource_ref(),
rmm::percent_of_free_device_memory(percent_of_free_memory)));
per_device_pools_.push_back(std::make_unique<rmm::mr::pool_memory_resource>(
rmm::mr::get_current_device_resource_ref(),
rmm::percent_of_free_device_memory(percent_of_free_memory)));
rmm::mr::set_per_device_resource_ref(rmm::cuda_device_id{device_id},
*per_device_pools_.back());
}
Expand Down Expand Up @@ -151,8 +149,7 @@ class device_resources_snmg : public device_resources {
}
}
int main_gpu_id_;
std::vector<std::unique_ptr<rmm::mr::pool_memory_resource<rmm::mr::device_memory_resource>>>
per_device_pools_;
std::vector<std::unique_ptr<rmm::mr::pool_memory_resource>> per_device_pools_;
std::vector<int> pool_device_ids_;
}; // class device_resources_snmg

Expand Down
11 changes: 5 additions & 6 deletions cpp/include/raft/core/handle.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2019-2023, NVIDIA CORPORATION.
* SPDX-FileCopyrightText: Copyright (c) 2019-2026, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/

Expand All @@ -21,9 +21,8 @@ namespace raft {
*/
class handle_t : public raft::device_resources {
public:
handle_t(const handle_t& handle,
std::shared_ptr<rmm::mr::device_memory_resource> workspace_resource)
: device_resources(handle, workspace_resource)
handle_t(const handle_t& handle, raft::mr::device_resource workspace_resource)
: device_resources(handle, std::move(workspace_resource))
{
}

Expand All @@ -43,8 +42,8 @@ class handle_t : public raft::device_resources {
*/
handle_t(rmm::cuda_stream_view stream_view = rmm::cuda_stream_per_thread,
std::shared_ptr<rmm::cuda_stream_pool> stream_pool = {nullptr},
std::shared_ptr<rmm::mr::device_memory_resource> workspace_resource = {nullptr})
: device_resources{stream_view, stream_pool, workspace_resource}
std::optional<raft::mr::device_resource> workspace_resource = std::nullopt)
: device_resources{stream_view, stream_pool, std::move(workspace_resource)}
{
}

Expand Down
Loading
Loading