Skip to content
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cpp/cmake/patches/faiss-1.14-cuvs-26.06.diff
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ index 548618262..3be071550 100644
- rmm::mr::device_memory_resource* current_mr =
- rmm::mr::get_per_device_resource(
+ auto current_mr =
+ rmm::mr::get_per_device_resource_ref(
+ rmm::mr::get_per_device_resource(
rmm::cuda_device_id{adjReq.device});
- p = current_mr->allocate(adjReq.stream, adjReq.size);
+ p = current_mr.allocate(adjReq.stream, adjReq.size);
Expand Down
6 changes: 3 additions & 3 deletions cpp/src/cluster/detail/kmeans_balanced.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ void compute_norm(const raft::resources& handle,
raft::common::nvtx::range<cuvs::common::nvtx::domain::cuvs> fun_scope("compute_norm");
auto stream = raft::resource::get_cuda_stream(handle);
rmm::device_uvector<MathT> mapped_dataset(
0, stream, mr.value_or(raft::resource::get_workspace_resource_ref(handle)));
0, stream, mr.value_or(raft::resource::get_workspace_resource(handle)));

const MathT* dataset_ptr = nullptr;

Expand Down Expand Up @@ -383,7 +383,7 @@ void predict(const raft::resources& handle,
auto stream = raft::resource::get_cuda_stream(handle);
raft::common::nvtx::range<cuvs::common::nvtx::domain::cuvs> fun_scope(
"predict(%zu, %u)", static_cast<size_t>(n_rows), n_clusters);
auto mem_res = mr.value_or(raft::resource::get_workspace_resource_ref(handle));
auto mem_res = mr.value_or(raft::resource::get_workspace_resource(handle));
auto [max_minibatch_size, _mem_per_row] = calc_minibatch_size<MathT>(
handle, n_clusters, n_rows, dim, params.metric, std::is_same_v<T, MathT>);
rmm::device_uvector<MathT> cur_dataset(
Expand Down Expand Up @@ -995,7 +995,7 @@ void build_hierarchical(const raft::resources& handle,

// TODO: Remove the explicit managed memory- we shouldn't be creating this on the user's behalf.
rmm::mr::managed_memory_resource managed_memory;
rmm::device_async_resource_ref device_memory = raft::resource::get_workspace_resource_ref(handle);
rmm::device_async_resource_ref device_memory = raft::resource::get_workspace_resource(handle);
auto [max_minibatch_size, mem_per_row] = calc_minibatch_size<MathT>(
handle, n_clusters, n_rows, dim, params.metric, std::is_same_v<T, MathT>);

Expand Down
4 changes: 2 additions & 2 deletions cpp/src/cluster/kmeans_balanced.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ void predict(const raft::resources& handle,
X.extent(0),
labels.data_handle(),
mapping_op,
raft::resource::get_workspace_resource_ref(handle));
raft::resource::get_workspace_resource(handle));
}

namespace helpers {
Expand Down Expand Up @@ -305,7 +305,7 @@ void calc_centers_and_sizes(const raft::resources& handle,
labels.data_handle(),
reset_counters,
mapping_op,
raft::resource::get_workspace_resource_ref(handle));
raft::resource::get_workspace_resource(handle));
}

} // namespace helpers
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/cluster/kmeans_balanced_build_clusters_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ void build_clusters(const raft::resources& handle,
labels.data_handle(),
cluster_sizes.data_handle(),
mapping_op,
raft::resource::get_workspace_resource_ref(handle),
raft::resource::get_workspace_resource(handle),
X_norm.has_value() ? X_norm.value().data_handle() : nullptr);
}

Expand Down
2 changes: 1 addition & 1 deletion cpp/src/distance/detail/masked_nn.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ void masked_l2_nn_impl(raft::resources const& handle,
static_assert(P::Mblk == 64, "masked_l2_nn_impl only supports a policy with 64 rows per block.");

// Get stream and workspace memory resource
rmm::device_async_resource_ref ws_mr = raft::resource::get_workspace_resource_ref(handle);
rmm::device_async_resource_ref ws_mr = raft::resource::get_workspace_resource(handle);
auto stream = raft::resource::get_cuda_stream(handle);

// Acquire temporary buffers and initialize to zero:
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/neighbors/composite/index.cu
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ void composite_index<T, IdxT, OutputIdxT>::search(
size_t buffer_size = num_queries * K * num_indices;

auto main_stream = raft::resource::get_cuda_stream(handle);
auto tmp_res = raft::resource::get_workspace_resource_ref(handle);
auto tmp_res = raft::resource::get_workspace_resource(handle);

rmm::device_uvector<out_index_type> neighbors_buffer(buffer_size, main_stream, tmp_res);
rmm::device_uvector<float> distances_buffer(buffer_size, main_stream, tmp_res);
Expand Down
12 changes: 6 additions & 6 deletions cpp/src/neighbors/detail/ann_utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -864,7 +864,7 @@ struct batch_load_iterator {
{
}

/** Convenience overload that uses `get_workspace_resource_ref(res)` as the memory resource. */
/** Convenience overload that uses `get_workspace_resource(res)` as the memory resource. */
batch_load_iterator(raft::resources const& res,
MdspanT input_view,
size_type batch_size,
Expand All @@ -876,7 +876,7 @@ struct batch_load_iterator {
input_view,
batch_size,
copy_stream,
raft::resource::get_workspace_resource_ref(res),
raft::resource::get_workspace_resource(res),
prefetch,
initialize,
host_writeback)
Expand Down Expand Up @@ -1036,7 +1036,7 @@ class batch_load_iterator_dyn {
{
}

/** Convenience overload that uses `get_workspace_resource_ref(res)` as the memory resource. */
/** Convenience overload that uses `get_workspace_resource(res)` as the memory resource. */
batch_load_iterator_dyn(raft::resources const& res,
T* ptr,
IdxT n_rows,
Expand All @@ -1052,7 +1052,7 @@ class batch_load_iterator_dyn {
row_width,
batch_size,
copy_stream,
raft::resource::get_workspace_resource_ref(res),
raft::resource::get_workspace_resource(res),
prefetch,
initialize,
host_writeback)
Expand Down Expand Up @@ -1229,7 +1229,7 @@ auto make_batch_load_iterator(raft::resources const& res,
host_writeback);
}

/** Convenience overload that uses `get_workspace_resource_ref(res)` as the memory resource. */
/** Convenience overload that uses `get_workspace_resource(res)` as the memory resource. */
template <typename T, typename IdxT = int64_t>
auto make_batch_load_iterator(raft::resources const& res,
T const* ptr,
Expand All @@ -1247,7 +1247,7 @@ auto make_batch_load_iterator(raft::resources const& res,
row_width,
batch_size,
copy_stream,
raft::resource::get_workspace_resource_ref(res),
raft::resource::get_workspace_resource(res),
prefetch,
initialize,
host_writeback);
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/neighbors/detail/cagra/add_nodes.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ void add_node_core(
params.itopk_size = std::max(base_degree * 2lu, 256lu);

// Memory space for rank-based neighbor list
auto mr = raft::resource::get_workspace_resource_ref(handle);
auto mr = raft::resource::get_workspace_resource(handle);

auto neighbor_indices = raft::make_device_mdarray<IdxT, std::int64_t>(
handle, mr, raft::make_extents<std::int64_t>(max_search_batch_size, base_degree));
Expand Down
6 changes: 3 additions & 3 deletions cpp/src/neighbors/detail/cagra/cagra_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1701,8 +1701,8 @@ void build_knn_graph(

// If the workspace is smaller than desired, put the I/O buffers into the large workspace.
rmm::device_async_resource_ref workspace_mr =
use_large_workspace ? raft::resource::get_large_workspace_resource_ref(res)
: raft::resource::get_workspace_resource_ref(res);
use_large_workspace ? raft::resource::get_large_workspace_resource(res)
: raft::resource::get_workspace_resource(res);
Comment on lines +1703 to +1705

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical | ⚡ Quick win

🧩 Analysis chain

🏁 Script executed:

# First, find the RAFT resource header files to understand the accessor API
fd -e hpp -e cuh -e cu -t f "resource" | grep -E "(raft|resource)" | head -20

Repository: rapidsai/cuvs

Length of output: 110


🏁 Script executed:

# Search for RAFT resource getter functions to understand the correct API
rg "get_workspace_resource" -A 2 -t cpp -t cuh -t hpp | head -50

Repository: rapidsai/cuvs

Length of output: 85


🏁 Script executed:

# Check the specific lines in cagra_build.cuh
sed -n '1700,1710p' cpp/src/neighbors/detail/cagra/cagra_build.cuh

Repository: rapidsai/cuvs

Length of output: 497


🏁 Script executed:

# Check for the same pattern in search_plan.cuh
fd "search_plan.cuh" -x cat {} | grep -n "get_workspace_resource" -A 1 -B 1

Repository: rapidsai/cuvs

Length of output: 404


🏁 Script executed:

# Check for the same pattern in utils.hpp
fd "utils.hpp" -t f -x grep -l "get_workspace_resource" {} \; 2>/dev/null | while read f; do echo "=== $f ==="; grep -n "get_workspace_resource" -A 1 -B 1 "$f"; done

Repository: rapidsai/cuvs

Length of output: 39


🏁 Script executed:

# Find RAFT resource headers
fd "resource" -t f -e hpp | grep -i raft

Repository: rapidsai/cuvs

Length of output: 39


🏁 Script executed:

# Search for resource getter definitions in RAFT
rg "get_workspace_resource_ref|get_large_workspace_resource_ref" -t cpp -t hpp

Repository: rapidsai/cuvs

Length of output: 85


🏁 Script executed:

# Look for resource management API in RAFT
fd "resource.hpp" -t f | head -5 | xargs ls -la 2>/dev/null

Repository: rapidsai/cuvs

Length of output: 159


🏁 Script executed:

# Search broader for resource definitions
rg "get_workspace_resource|get_large_workspace_resource" -t hpp -B 1 -A 1 | head -80

Repository: rapidsai/cuvs

Length of output: 85


🏁 Script executed:

# Check RMM device_async_resource_ref type
rg "device_async_resource_ref" -t hpp -B 2 -A 2 | head -50

Repository: rapidsai/cuvs

Length of output: 85


🏁 Script executed:

# Search without file type restrictions
find . -name "*.hpp" -o -name "*.h" -o -name "*.cuh" -o -name "*.cu" | xargs grep -l "get_workspace_resource" 2>/dev/null | head -20

Repository: rapidsai/cuvs

Length of output: 970


🏁 Script executed:

# Look for raft resource headers
find . -path "*/raft/*" -name "*resource*" -type f 2>/dev/null | head -20

Repository: rapidsai/cuvs

Length of output: 39


🏁 Script executed:

# Search in the codebase for resource.hpp or similar
find . -name "resource.hpp" -o -name "resources.hpp" 2>/dev/null

Repository: rapidsai/cuvs

Length of output: 39


🏁 Script executed:

# Check what RAFT version is being used
find . -name "CMakeLists.txt" -o -name "package.json" -o -name "conanfile.txt" | xargs grep -i "raft" 2>/dev/null | head -20

Repository: rapidsai/cuvs

Length of output: 1193


🏁 Script executed:

# Search for function definitions containing get_workspace_resource
grep -r "get_workspace_resource" --include="*.hpp" --include="*.h" --include="*.cuh" . 2>/dev/null | grep -E "(return|->|{)" | head -30

Repository: rapidsai/cuvs

Length of output: 81


🏁 Script executed:

# Simpler search for resource getters
grep -r "get_workspace_resource" . 2>/dev/null | grep "inline\|return" | head -20

Repository: rapidsai/cuvs

Length of output: 39


🏁 Script executed:

# Look at what files include resource management
grep -r "resource::get_workspace_resource" . 2>/dev/null | head -10

Repository: rapidsai/cuvs

Length of output: 1238


🏁 Script executed:

# Check if there are RAFT headers included
grep -r "raft/resource" . 2>/dev/null | head -20

Repository: rapidsai/cuvs

Length of output: 39


🏁 Script executed:

# Look at the actual failing file to see the full context
head -50 cpp/src/neighbors/detail/cagra/cagra_build.cuh | tail -30

Repository: rapidsai/cuvs

Length of output: 760


🏁 Script executed:

# Search for includes in cagra_build.cuh to understand dependencies
grep "include" cpp/src/neighbors/detail/cagra/cagra_build.cuh | head -20

Repository: rapidsai/cuvs

Length of output: 813


🏁 Script executed:

# Look for documentation or comments about resource types
grep -A 5 -B 5 "device_async_resource_ref" cpp/src/neighbors/detail/cagra/cagra_build.cuh | head -30

Repository: rapidsai/cuvs

Length of output: 531


🏁 Script executed:

# Search for any usage of _ref suffix versions in the codebase
grep -r "get_workspace_resource_ref\|get_large_workspace_resource_ref" . 2>/dev/null

Repository: rapidsai/cuvs

Length of output: 39


🏁 Script executed:

# Look for RAFT resource header files in the repo
find . -path "*raft*resource*" -type f 2>/dev/null | head -20

Repository: rapidsai/cuvs

Length of output: 39


🏁 Script executed:

# Check what's included from raft in cagra_build.cuh
grep "^#include.*raft" cpp/src/neighbors/detail/cagra/cagra_build.cuh

Repository: rapidsai/cuvs

Length of output: 435


🏁 Script executed:

# Look for RAFT version or documentation
find . -name "*.md" -o -name "*.rst" -o -name "*.txt" | xargs grep -l "resource\|workspace" 2>/dev/null | head -10

Repository: rapidsai/cuvs

Length of output: 367


🏁 Script executed:

# Check if this code actually compiles - look for build artifacts or CI
find . -name "*.cmake" -o -name "CMakeLists.txt" 2>/dev/null | xargs grep -l "cagra_build" 2>/dev/null

Repository: rapidsai/cuvs

Length of output: 105


🌐 Web query:

RAFT library get_workspace_resource API documentation resource_ref

💡 Result:

In the NVIDIA RAFT library, get_workspace_resource and get_workspace_resource_ref are related C++ APIs used to manage temporary workspace memory [1][2]. The get_workspace_resource API is defined in the raft::resource namespace [1][3]. It returns a pointer to an rmm::mr::limiting_resource_adaptor, which is used for temporary workspace allocations [1][2]. The function signature is: inline rmm::mr::limiting_resource_adaptor *get_workspace_resource(resources const &res) The documentation explicitly recommends using get_workspace_resource_ref for performing allocations instead of raw pointer access [1]. This function returns an rmm::device_async_resource_ref, which provides a safer, more modern interface for managing device memory resources within the RAFT framework [1]. Key points regarding these APIs: - Purpose: They allow RAFT algorithms and primitives to access a temporary workspace resource, which is managed via the raft::resources object [1][2][3]. - Resource Management: If the workspace resource has not been initialized on the provided resources object, calling these functions will attempt to load or populate it [1][2]. - Related Functions: RAFT also provides functions for workspace accounting, such as get_workspace_total_bytes, get_workspace_used_bytes, and get_workspace_free_bytes [1][2]. - Configuration: You can configure the workspace resource using set_workspace_resource, set_workspace_to_pool_resource, or set_workspace_to_global_resource [1][2][3]. For the most up-to-date information, you should refer to the official RAFT documentation for the version you are using, as these APIs are part of the core resources management layer [1][4].

Citations:


🌐 Web query:

rmm device_async_resource_ref RAFT workspace resource getter

💡 Result:

In RAFT, the “workspace resource getter” returns an RMM non-owning async device memory resource reference: - get_workspace_resource_ref(resources const &res) returns rmm::device_async_resource_ref, described as “Get the workspace as a non-owning device_async_resource_ref.” [1] What rmm::device_async_resource_ref is (RMM/CCCL terminology): - rmm::device_async_resource_ref is an alias for a cuda::mr::async_resource_ref with the cuda::mr::device_accessible property. [2] So, for your query: - rmm::device_async_resource_ref + RAFT workspace resource getter refers to raft::core::resources “get_workspace_resource_ref”, which yields a non-owning workspace device async memory resource reference of type rmm::device_async_resource_ref. [1]

Citations:


🌐 Web query:

RAFT github rapidsai resource management get_workspace_resource_ref

💡 Result:

In the RAPIDS RAFT library, get_workspace_resource_ref is a C++ function used to retrieve the workspace as a non-owning device_async_resource_ref [1][2]. It is defined within the raft::resource namespace and is designed for memory resource management in CUDA-accelerated applications [1][2]. Key details regarding this function: Functionality: It provides a non-owning reference to the workspace device memory resource managed by the raft::resources object [1][2]. Usage: It is the preferred method for allocations, as opposed to older pointer-based APIs like get_workspace_resource, which are being phased out in favor of reference-based alternatives to align with modern RMM (RAPIDS Memory Manager) and CCCL memory resource patterns [1][3][4][2]. Related APIs: For accounting queries, users should utilize get_workspace_total_bytes, get_workspace_used_bytes, and get_workspace_free_bytes [1][2]. Similarly, get_large_workspace_resource_ref exists for managing larger workspace resources [1][2]. This shift to reference-based APIs is part of a broader migration within the RAFT codebase to improve memory resource safety and consistency [3][4].

Citations:


🏁 Script executed:

# Verify search_plan.cuh has the same issue
grep -n "get_workspace_resource\|get_large_workspace_resource" cpp/src/neighbors/detail/cagra/search_plan.cuh | head -20

Repository: rapidsai/cuvs

Length of output: 218


🏁 Script executed:

# Look for utils.hpp in cagra directory
find cpp/src/neighbors/detail/cagra -name "utils.hpp" -o -name "utils.cuh"

Repository: rapidsai/cuvs

Length of output: 98


🏁 Script executed:

# Check if there's a utils.hpp in the parent or sibling directories
find cpp/src/neighbors -name "utils.hpp" | head -10

Repository: rapidsai/cuvs

Length of output: 98


🏁 Script executed:

# Search for device_async_resource_ref assignments in cagra directory
grep -n "device_async_resource_ref" cpp/src/neighbors/detail/cagra/*.cuh

Repository: rapidsai/cuvs

Length of output: 298


🏁 Script executed:

# Check utils.hpp for the resource getter issue
grep -n "get_workspace_resource\|get_large_workspace_resource\|device_async_resource_ref" cpp/src/neighbors/detail/cagra/utils.hpp

Repository: rapidsai/cuvs

Length of output: 234


🏁 Script executed:

# Look at the context around the resource assignment in search_plan.cuh
sed -n '40,85p' cpp/src/neighbors/detail/cagra/search_plan.cuh

Repository: rapidsai/cuvs

Length of output: 1810


🏁 Script executed:

# Check if there are compilation tests or if code compiles successfully
find . -name "*.log" -o -name "*build*" -type d 2>/dev/null | head -5

Repository: rapidsai/cuvs

Length of output: 77


Fix API mismatch in workspace resource accessor calls.

The code assigns results from get_workspace_resource() and get_large_workspace_resource() (non-_ref versions) to variables typed as rmm::device_async_resource_ref. These non-_ref getters return raw pointers (rmm::mr::limiting_resource_adaptor*), not reference types. The modern RAFT API provides get_workspace_resource_ref() and get_large_workspace_resource_ref() which correctly return rmm::device_async_resource_ref and are the intended accessors for resource management.

Update the following locations:

  • cpp/src/neighbors/detail/cagra/cagra_build.cuh line 1703
  • cpp/src/neighbors/detail/cagra/search_plan.cuh lines 61, 83
  • cpp/src/neighbors/detail/cagra/utils.hpp line 193
🔧 Proposed fix example
-  rmm::device_async_resource_ref workspace_mr =
-    use_large_workspace ? raft::resource::get_large_workspace_resource(res)
-                        : raft::resource::get_workspace_resource(res);
+  rmm::device_async_resource_ref workspace_mr =
+    use_large_workspace ? raft::resource::get_large_workspace_resource_ref(res)
+                        : raft::resource::get_workspace_resource_ref(res);
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-  rmm::device_async_resource_ref workspace_mr =
-    use_large_workspace ? raft::resource::get_large_workspace_resource(res)
-                        : raft::resource::get_workspace_resource(res);
+  rmm::device_async_resource_ref workspace_mr =
+    use_large_workspace ? raft::resource::get_large_workspace_resource_ref(res)
+                        : raft::resource::get_workspace_resource_ref(res);
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@cpp/src/neighbors/detail/cagra/cagra_build.cuh` around lines 1703-1705, the
workspace resource variable is typed as rmm::device_async_resource_ref but is
initialized from the pointer-returning getters. Replace the calls to
get_workspace_resource(...) and get_large_workspace_resource(...) with the RAFT
ref-returning APIs get_workspace_resource_ref(...) and
get_large_workspace_resource_ref(...) so the returned type matches
rmm::device_async_resource_ref. Update the assignment to workspace_mr in
cagra_build.cuh, the two get_workspace_resource(...) calls in search_plan.cuh
(lines 61 and 83), and the get_large_workspace_resource(...) call in utils.hpp
(line 193) to use the *_ref variants.


RAFT_LOG_DEBUG(
"IVF-PQ search node_degree: %d, top_k: %d, gpu_top_k: %d, max_batch_size:: %d, n_probes: %u",
Expand Down Expand Up @@ -2123,7 +2123,7 @@ auto iterative_build_graph(
static_cast<int64_t>(dev_query_view.extent(1)),
max_chunk_size,
raft::resource::get_cuda_stream(res),
raft::resource::get_workspace_resource_ref(res));
raft::resource::get_workspace_resource(res));
for (const auto& batch : query_batch) {
auto batch_dev_query_view = raft::make_device_matrix_view<const T, int64_t>(
batch.data(), batch.size(), dev_query_view.extent(1));
Expand Down
10 changes: 5 additions & 5 deletions cpp/src/neighbors/detail/cagra/graph_core.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -851,7 +851,7 @@ void merge_graph_gpu(

namespace bli = cuvs::spatial::knn::detail::utils;
auto [copy_stream, enable_prefetch] = bli::get_prefetch_stream(res);
auto workspace_mr = raft::resource::get_workspace_resource_ref(res);
auto workspace_mr = raft::resource::get_workspace_resource(res);

bli::batch_load_iterator<
raft::mdspan<IdxT, raft::matrix_extent<int64_t>, raft::row_major, AccessorOutputGraph>>
Expand Down Expand Up @@ -1001,7 +1001,7 @@ void sort_knn_graph(
const uint64_t input_graph_degree = knn_graph.extent(1);
IdxT* const input_graph_ptr = knn_graph.data_handle();

auto large_tmp_mr = raft::resource::get_large_workspace_resource_ref(res);
auto large_tmp_mr = raft::resource::get_large_workspace_resource(res);

auto d_input_graph = raft::make_device_mdarray<IdxT>(
res, large_tmp_mr, raft::make_extents<int64_t>(graph_size, input_graph_degree));
Expand Down Expand Up @@ -1625,7 +1625,7 @@ void prune_graph_gpu(

namespace bli = cuvs::spatial::knn::detail::utils;
auto [copy_stream, enable_prefetch] = bli::get_prefetch_stream(res);
auto workspace_mr = raft::resource::get_workspace_resource_ref(res);
auto workspace_mr = raft::resource::get_workspace_resource(res);

// Single-batch read-only iterator for the input graph (graph_size rows fit in one batch).
bli::batch_load_iterator<
Expand Down Expand Up @@ -1722,9 +1722,9 @@ void optimize(
raft::resources res{res_const};

// large temporary memory for large arrays, e.g. everything >= O(graph_size)
auto large_tmp_mr = raft::resource::get_large_workspace_resource_ref(res);
auto large_tmp_mr = raft::resource::get_large_workspace_resource(res);
// temporary memory for small arrays, e.g. everything <= O(batchsize * graph_degree)
auto default_ws_mr = raft::resource::get_workspace_resource_ref(res);
auto default_ws_mr = raft::resource::get_workspace_resource(res);

// create a stream pool if not already present
if (!res.has_resource_factory(raft::resource::resource_type::CUDA_STREAM_POOL) ||
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/neighbors/detail/cagra/search_plan.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ struct lightweight_uvector {
if (new_size == size_) { return; }
if (std::holds_alternative<raft_res_type>(res_)) {
auto& h = std::get<raft_res_type>(res_);
res_ = rmm_res_type{raft::resource::get_workspace_resource_ref(*h),
res_ = rmm_res_type{raft::resource::get_workspace_resource(*h),
raft::resource::get_cuda_stream(*h)};
}
auto& [r, s] = std::get<rmm_res_type>(res_);
Expand All @@ -80,7 +80,7 @@ struct lightweight_uvector {
if (new_size == size_) { return; }
if (std::holds_alternative<raft_res_type>(res_)) {
auto& h = std::get<raft_res_type>(res_);
res_ = rmm_res_type{raft::resource::get_workspace_resource_ref(*h), stream};
res_ = rmm_res_type{raft::resource::get_workspace_resource(*h), stream};
} else {
std::get<rmm::cuda_stream_view>(std::get<rmm_res_type>(res_)) = stream;
}
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/neighbors/detail/cagra/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ class device_matrix_view_from_host {
// live on stack and not returned to a user.
// The user may opt to set this resource to managed memory to allow large allocations.
device_mem_.emplace(raft::make_device_mdarray<T, IdxT>(
res, raft::resource::get_large_workspace_resource_ref(res), host_view.extents()));
res, raft::resource::get_large_workspace_resource(res), host_view.extents()));
raft::copy(res, device_mem_->view(), host_view);
device_ptr = device_mem_->data_handle();
}
Expand Down
30 changes: 15 additions & 15 deletions cpp/src/neighbors/detail/vamana/vamana_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -143,19 +143,19 @@ void batched_insert_vamana(
auto query_ids = raft::make_device_vector<IdxT>(res, max_batchsize);
auto query_list_ptr = raft::make_device_mdarray<QueryCandidates<IdxT, accT>>(
res,
raft::resource::get_large_workspace_resource_ref(res),
raft::resource::get_large_workspace_resource(res),
raft::make_extents<int64_t>(max_batchsize + 1));
QueryCandidates<IdxT, accT>* query_list =
static_cast<QueryCandidates<IdxT, accT>*>(query_list_ptr.data_handle());

// Results of each batch of inserts during build - Memory is used by query_list structure
auto visited_ids =
raft::make_device_mdarray<IdxT>(res,
raft::resource::get_large_workspace_resource_ref(res),
raft::resource::get_large_workspace_resource(res),
raft::make_extents<int64_t>(max_batchsize, visited_size));
auto visited_dists =
raft::make_device_mdarray<accT>(res,
raft::resource::get_large_workspace_resource_ref(res),
raft::resource::get_large_workspace_resource(res),
raft::make_extents<int64_t>(max_batchsize, visited_size));

// Assign memory to query_list structures and initialize
Expand All @@ -167,14 +167,14 @@ void batched_insert_vamana(
1);
auto topk_pq_mem =
raft::make_device_mdarray<Node<accT>>(res,
raft::resource::get_large_workspace_resource_ref(res),
raft::resource::get_large_workspace_resource(res),
raft::make_extents<int64_t>(max_batchsize, visited_size));

int align_padding = raft::alignTo(dim, 16) - dim;

auto s_coords_mem = raft::make_device_mdarray<T>(
res,
raft::resource::get_large_workspace_resource_ref(res),
raft::resource::get_large_workspace_resource(res),
raft::make_extents<int64_t>(min(maxBlocks, max(max_batchsize, reverse_batch)),
dim + align_padding));

Expand Down Expand Up @@ -318,7 +318,7 @@ void batched_insert_vamana(

// compute prefix sums of query_list sizes - TODO parallelize prefix sums
// auto d_total_edges = raft::make_device_mdarray<int>(
// res, raft::resource::get_workspace_resource_ref(res), raft::make_extents<int64_t>(1));
// res, raft::resource::get_workspace_resource(res), raft::make_extents<int64_t>(1));
rmm::device_scalar<int> d_total_edges(stream);
prefix_sums_sizes<accT, IdxT><<<1, 1, 0, stream>>>(query_list, step_size, d_total_edges.data());
RAFT_CUDA_TRY(cudaPeekAtLastError());
Expand All @@ -329,16 +329,16 @@ void batched_insert_vamana(

auto edge_dist_pair = raft::make_device_mdarray<DistPair<IdxT, accT>>(
res,
raft::resource::get_large_workspace_resource_ref(res),
raft::resource::get_large_workspace_resource(res),
raft::make_extents<int64_t>(total_edges));

auto edge_dest =
raft::make_device_mdarray<IdxT>(res,
raft::resource::get_large_workspace_resource_ref(res),
raft::resource::get_large_workspace_resource(res),
raft::make_extents<int64_t>(total_edges));
auto edge_src =
raft::make_device_mdarray<IdxT>(res,
raft::resource::get_large_workspace_resource_ref(res),
raft::resource::get_large_workspace_resource(res),
raft::make_extents<int64_t>(total_edges));

// Create reverse edge list
Expand Down Expand Up @@ -367,7 +367,7 @@ void batched_insert_vamana(

auto temp_sort_storage = raft::make_device_mdarray<IdxT>(
res,
raft::resource::get_large_workspace_resource_ref(res),
raft::resource::get_large_workspace_resource(res),
raft::make_extents<int64_t>(temp_storage_bytes / sizeof(IdxT)));

// Sort to group reverse edges by destination
Expand Down Expand Up @@ -406,7 +406,7 @@ void batched_insert_vamana(

auto temp_sort_storage = raft::make_device_mdarray<IdxT>(
res,
raft::resource::get_large_workspace_resource_ref(res),
raft::resource::get_large_workspace_resource(res),
raft::make_extents<int64_t>(temp_storage_bytes / sizeof(IdxT)));

// Sort to group reverse edges by destination
Expand Down Expand Up @@ -451,16 +451,16 @@ void batched_insert_vamana(
// Allocate reverse QueryCandidate list based on number of unique destinations
auto reverse_list_ptr = raft::make_device_mdarray<QueryCandidates<IdxT, accT>>(
res,
raft::resource::get_large_workspace_resource_ref(res),
raft::resource::get_large_workspace_resource(res),
raft::make_extents<int64_t>(reverse_batch));
auto rev_ids =
raft::make_device_mdarray<IdxT>(res,
raft::resource::get_large_workspace_resource_ref(res),
raft::resource::get_large_workspace_resource(res),
raft::make_extents<int64_t>(reverse_batch, visited_size));

auto rev_dists =
raft::make_device_mdarray<accT>(res,
raft::resource::get_large_workspace_resource_ref(res),
raft::resource::get_large_workspace_resource(res),
raft::make_extents<int64_t>(reverse_batch, visited_size));

QueryCandidates<IdxT, accT>* reverse_list =
Expand Down Expand Up @@ -663,7 +663,7 @@ index<T, IdxT> build(
static_cast<int64_t>(dim),
static_cast<size_t>(max_batch_size),
raft::resource::get_cuda_stream(res),
raft::resource::get_workspace_resource_ref(res));
raft::resource::get_workspace_resource(res));
for (const auto& batch : _vamana_batches) {
// perform rotation
auto dataset_rotated = raft::make_device_matrix<float, int64_t>(res, batch.size(), dim);
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/neighbors/detail/vpq_dataset.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -910,7 +910,7 @@ void process_and_fill_codes_subspaces(
static_cast<ix_t>(dim),
static_cast<size_t>(max_batch_size),
copy_stream,
raft::resource::get_workspace_resource_ref(res),
raft::resource::get_workspace_resource(res),
enable_prefetch_stream);
vec_batches.prefetch_next_batch();
for (const auto& batch : vec_batches) {
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/neighbors/ivf_common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ template <typename Index>
void recompute_internal_state(const raft::resources& res, Index& index)
{
auto stream = raft::resource::get_cuda_stream(res);
auto tmp_res = raft::resource::get_workspace_resource_ref(res);
auto tmp_res = raft::resource::get_workspace_resource(res);
rmm::device_uvector<uint32_t> sorted_sizes(index.n_lists(), stream, tmp_res);

// Actualize the list pointers
Expand Down
Loading
Loading