-
Notifications
You must be signed in to change notification settings - Fork 236
Migrate RMM usage to CCCL MR design #2996
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
c67f9fa
d0232d3
c36d03b
414d6ad
f05df96
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,5 +1,5 @@ | ||
| /* | ||
| * SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION. | ||
| * SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION. | ||
| * SPDX-License-Identifier: Apache-2.0 | ||
| */ | ||
|
|
||
|
|
@@ -12,6 +12,7 @@ | |
| #include <rmm/cuda_stream_pool.hpp> | ||
| #include <rmm/mr/cuda_memory_resource.hpp> | ||
| #include <rmm/mr/per_device_resource.hpp> | ||
| #include <rmm/mr/pool_memory_resource.hpp> | ||
|
|
||
| #include <algorithm> | ||
| #include <memory> | ||
|
|
@@ -117,9 +118,7 @@ struct device_resources_manager { | |
| // Optional specification of separate workspace memory resources for each | ||
| // device. The integer in each pair indicates the device for this memory | ||
| // resource. | ||
| std::vector<std::pair<std::shared_ptr<rmm::mr::device_memory_resource>, int>> workspace_mrs{}; | ||
|
|
||
| auto get_workspace_memory_resource(int device_id) {} | ||
| std::vector<std::pair<raft::mr::device_resource, int>> workspace_mrs{}; | ||
| } params_; | ||
|
|
||
| // This struct stores the underlying resources to be shared among | ||
|
|
@@ -152,36 +151,31 @@ struct device_resources_manager { | |
| }()}, | ||
| pool_mr_{[&params, this]() { | ||
| auto scoped_device = device_setter{device_id_}; | ||
| auto result = | ||
| std::shared_ptr<rmm::mr::pool_memory_resource<rmm::mr::cuda_memory_resource>>{nullptr}; | ||
| auto result = std::optional<rmm::mr::pool_memory_resource>{}; | ||
| // If max_mem_pool_size is nullopt or non-zero, create a pool memory | ||
| // resource | ||
| if (params.max_mem_pool_size.value_or(1) != 0) { | ||
| auto* upstream = | ||
| dynamic_cast<rmm::mr::cuda_memory_resource*>(rmm::mr::get_current_device_resource()); | ||
| if (upstream != nullptr) { | ||
| result = | ||
| std::make_shared<rmm::mr::pool_memory_resource<rmm::mr::cuda_memory_resource>>( | ||
| upstream, | ||
| params.init_mem_pool_size.value_or(rmm::percent_of_free_device_memory(50)), | ||
| params.max_mem_pool_size); | ||
| rmm::mr::set_current_device_resource(result.get()); | ||
| } else { | ||
| RAFT_LOG_WARN( | ||
| "Pool allocation requested, but other memory resource has already been set and " | ||
| "will not be overwritten"); | ||
| } | ||
| // TODO: reinstate the dynamic_cast<cuda_memory_resource*> guard that | ||
| // skipped pool creation when a non-default resource was already set, | ||
| // once CCCL exposes resource_cast or an equivalent type-query. | ||
|
Comment on lines
+158
to
+160
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I have marked this as a post-task for me to complete. There is a similar need in cuOpt. CCCL recently added the |
||
| auto upstream = rmm::mr::get_current_device_resource_ref(); | ||
| result.emplace( | ||
| upstream, | ||
| params.init_mem_pool_size.value_or(rmm::percent_of_free_device_memory(50)), | ||
| params.max_mem_pool_size); | ||
| rmm::mr::set_current_device_resource(*result); | ||
| } | ||
| return result; | ||
| }()}, | ||
| workspace_mr_{[&params, this]() { | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Re-add |
||
| auto result = std::shared_ptr<rmm::mr::device_memory_resource>{nullptr}; | ||
| auto result = std::optional<raft::mr::device_resource>{}; | ||
| auto iter = std::find_if(std::begin(params.workspace_mrs), | ||
| std::end(params.workspace_mrs), | ||
| [this](auto&& pair) { return pair.second == device_id_; }); | ||
| if (iter != std::end(params.workspace_mrs)) { result = iter->first; } | ||
| return result; | ||
| }()} | ||
| }()}, | ||
| workspace_allocation_limit_{params.workspace_allocation_limit} | ||
| { | ||
| } | ||
|
|
||
|
|
@@ -216,26 +210,23 @@ struct device_resources_manager { | |
| if (pool_count() != 0) { result = pools_[get_thread_id() % pool_count()]; } | ||
| return result; | ||
| } | ||
| // Return a (possibly null) shared_ptr to the pool memory resource | ||
| // created for this device by the manager | ||
| [[nodiscard]] auto get_pool_memory_resource() const { return pool_mr_; } | ||
| // Return the pool memory resource created for this device by the manager (if any) | ||
| [[nodiscard]] auto& get_pool_memory_resource() { return pool_mr_; } | ||
| // Return the RAFT workspace allocation limit that will be used by | ||
| // `device_resources` returned from this manager | ||
| [[nodiscard]] auto get_workspace_allocation_limit() const | ||
| { | ||
| return workspace_allocation_limit_; | ||
| } | ||
| // Return a (possibly null) shared_ptr to the memory resource that will | ||
| // be used for workspace allocations by `device_resources` returned from | ||
| // this manager | ||
| [[nodiscard]] auto get_workspace_memory_resource() { return workspace_mr_; } | ||
| // Return the workspace memory resource for this device (if any) | ||
| [[nodiscard]] auto& get_workspace_memory_resource() { return workspace_mr_; } | ||
|
|
||
| private: | ||
| int device_id_; | ||
| std::unique_ptr<rmm::cuda_stream_pool> streams_; | ||
| std::vector<std::shared_ptr<rmm::cuda_stream_pool>> pools_; | ||
| std::shared_ptr<rmm::mr::pool_memory_resource<rmm::mr::cuda_memory_resource>> pool_mr_; | ||
| std::shared_ptr<rmm::mr::device_memory_resource> workspace_mr_; | ||
| std::optional<rmm::mr::pool_memory_resource> pool_mr_; | ||
| std::optional<raft::mr::device_resource> workspace_mr_; | ||
| std::optional<std::size_t> workspace_allocation_limit_{std::nullopt}; | ||
| }; | ||
|
|
||
|
|
@@ -374,8 +365,7 @@ struct device_resources_manager { | |
| } | ||
|
|
||
| // Thread-safe setter for workspace memory resources | ||
| void set_workspace_memory_resource_(std::shared_ptr<rmm::mr::device_memory_resource> mr, | ||
| int device_id) | ||
| void set_workspace_memory_resource_(raft::mr::device_resource mr, int device_id) | ||
| { | ||
| auto lock = get_lock(); | ||
| if (params_finalized_) { | ||
|
|
@@ -387,9 +377,9 @@ struct device_resources_manager { | |
| std::end(params_.workspace_mrs), | ||
| [device_id](auto&& pair) { return pair.second == device_id; }); | ||
| if (iter != std::end(params_.workspace_mrs)) { | ||
| iter->first = mr; | ||
| iter->first = std::move(mr); | ||
| } else { | ||
| params_.workspace_mrs.emplace_back(mr, device_id); | ||
| params_.workspace_mrs.emplace_back(std::move(mr), device_id); | ||
| } | ||
| } | ||
| } | ||
|
|
@@ -548,7 +538,7 @@ struct device_resources_manager { | |
| * @brief Set the workspace memory resource to be used on a specific device | ||
| * | ||
| * RAFT device_resources objects can be built with a separate memory | ||
| * resource for allocating temporary workspaces. If a (non-nullptr) memory | ||
| * resource for allocating temporary workspaces. If a memory | ||
| * resource is provided by this setter, it will be used as the | ||
| * workspace memory resource for all `device_resources` returned for the | ||
| * indicated device. | ||
|
|
@@ -557,10 +547,10 @@ struct device_resources_manager { | |
| * `raft::device_resources_manager::get_device_resources`, no change will be made, | ||
| * and a warning will be emitted. | ||
| */ | ||
| static void set_workspace_memory_resource(std::shared_ptr<rmm::mr::device_memory_resource> mr, | ||
| static void set_workspace_memory_resource(raft::mr::device_resource mr, | ||
| int device_id = device_setter::get_current_device()) | ||
| { | ||
| get_manager().set_workspace_memory_resource_(mr, device_id); | ||
| get_manager().set_workspace_memory_resource_(std::move(mr), device_id); | ||
| } | ||
| }; | ||
| } // namespace raft | ||
Uh oh!
There was an error while loading. Please reload this page.