
Commit 5f15f19

Speed up recall calculation in cuVS Bench for large top-K (#1816)
Currently, recall calculation in cuVS Bench essentially runs an outer loop over the `k` ground-truth vector IDs and an inner loop over the `k` ANN result vector IDs, incrementing a counter whenever a result matches a ground-truth ID. This works well when `k` is small, but the complexity is `O(k^2)`. When benchmarking use cases with large `k` values, the recall calculation becomes a bottleneck, especially since a large `k` does not necessarily lead to much slower search times: the recall calculation is performed about as many times as it would be for a small `k`, leading to unacceptable (or at least humanly unbearable) run times.

This update speeds up the recall calculation in the following ways:

1. Eager hashing of vector IDs
   - During construction of the dataset, we populate for each query a `std::unordered_map` of {vector_id, neighbor_rank}. This step has complexity `O(k)` and the hash maps are cached for all benchmark cases.
   - During search, we look up each search result in the ground-truth map to determine whether it is a true result. This step has complexity `O(k)`.
2. Parallelizing hash map build and lookup
   - We use basic threading to parallelize recall calculation at the query level (for ease of implementation and cache locality).
   - Care is taken to avoid oversubscribing the CPU when benchmarking is run on multiple threads, e.g. in throughput mode.
3. Capping the total number of queries for which recall is calculated at about 10,000
   - This avoids unbounded recall calculations when using large sets of queries and ground truths while performing many iterations of the benchmark case.
   - The underlying assumption is that the sample of queries used for recall calculation is representative of the recall performance of the benchmark case tested.
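The hash-based matching in point 1 can be sketched in isolation. This is a minimal illustration of the idea, not the PR's actual API: `build_gt_map` and `count_matches` here are simplified stand-ins for the `ground_truth_map` struct the commit adds.

```cpp
#include <cassert>
#include <cstdint>
#include <unordered_map>
#include <vector>

// Build a {vector_id -> neighbor_rank} map for one query's ground truth: O(k).
std::unordered_map<std::int32_t, std::int32_t> build_gt_map(
  const std::vector<std::int32_t>& gt_ids)
{
  std::unordered_map<std::int32_t, std::int32_t> m;
  for (std::int32_t rank = 0; rank < static_cast<std::int32_t>(gt_ids.size()); ++rank) {
    m[gt_ids[rank]] = rank;
  }
  return m;
}

// Count how many ANN results appear among the top-k ground-truth neighbors.
// Each lookup is O(1) expected, so the whole pass is O(k) instead of the
// O(k^2) nested-loop comparison.
std::size_t count_matches(const std::unordered_map<std::int32_t, std::int32_t>& gt_map,
                          const std::vector<std::int32_t>& candidates,
                          std::size_t k)
{
  std::size_t matching = 0;
  for (std::size_t i = 0; i < k && i < candidates.size(); ++i) {
    auto it = gt_map.find(candidates[i]);
    // Only ground-truth neighbors with rank < k count toward recall@k.
    if (it != gt_map.end() && static_cast<std::size_t>(it->second) < k) { ++matching; }
  }
  return matching;
}
```

Recall for a query is then `matching / double(k)`; since the map is built once per query and cached, the per-iteration cost of recall evaluation drops from quadratic to linear in `k`.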
Testing at k=15000, batch-size=500, iterations=20, cpu=AMD EPYC 7413 (24 cores / 48 threads):

- baseline wall time: 285 s
- improved wall time: 3.7 s
- Note that the wall time includes loading and running the benchmarks, which takes over 1 s at these settings.
- Also note that if the number of iterations is not specified, the benchmark would run for over 100 iterations, which would make the baseline runtime much slower, as recall calculation would be performed for far more than 10,000 queries.
- At k=10, wall times are 1.369 s (PR) vs. 1.362 s (baseline).

Authors:
- James Xia (https://github.com/jamxia155)
- Anupam (https://github.com/aamijar)

Approvers:
- Artem M. Chirkin (https://github.com/achirkin)

URL: #1816
1 parent 74681b5 commit 5f15f19

2 files changed

Lines changed: 158 additions & 52 deletions


cpp/bench/ann/src/common/benchmark.hpp

Lines changed: 43 additions & 39 deletions
@@ -1,5 +1,5 @@
 /*
- * SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION.
+ * SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION.
  * SPDX-License-Identifier: Apache-2.0
  */
 #pragma once
@@ -351,15 +351,9 @@ void bench_search(::benchmark::State& state,

     // Each thread calculates recall on their partition of queries.
     // evaluate recall
-    if (dataset->max_k() >= k) {
-      const std::int32_t* gt = dataset->gt_set();
-      const std::uint32_t* filter_bitset = dataset->filter_bitset(MemoryType::kHostMmap);
-      auto filter = [filter_bitset](std::int32_t i) -> bool {
-        if (filter_bitset == nullptr) { return true; }
-        auto word = filter_bitset[i >> 5];
-        return word & (1 << (i & 31));
-      };
-      const std::uint32_t max_k = dataset->max_k();
+    if (dataset->max_k() >= k && dataset->gt_maps().has_value()) {
+      // gt_maps[i] is a hash map of {id, neighbor_rank} for query i
+      const auto& gt_maps = dataset->gt_maps();
       result_buf.transfer_data(MemoryType::kHost, current_algo_props->query_memory_type);
       auto* neighbors_host = reinterpret_cast<index_type*>(result_buf.data(MemoryType::kHost));
       std::size_t rows = std::min(queries_processed, query_set_size);
@@ -369,39 +363,49 @@ void bench_search(::benchmark::State& state,
       // We go through the groundtruth with same stride as the benchmark loop.
       size_t out_offset = 0;
       size_t batch_offset = (state.thread_index() * n_queries) % query_set_size;
+      // Avoid CPU oversubscription when parallelizing recall calculation loop
+      int num_recall_calculation_worker_threads =
+        std::thread::hardware_concurrency() / benchmark_n_threads - 1;  // -1 for the main thread
+      // ensure non-negative number of workers (possible if hardware_concurrency()
+      // does not return an expected value) by clamping to 0
+      if (num_recall_calculation_worker_threads < 0) { num_recall_calculation_worker_threads = 0; }
       while (out_offset < rows) {
-        for (std::size_t i = 0; i < n_queries; i++) {
-          size_t i_orig_idx = batch_offset + i;
-          size_t i_out_idx  = out_offset + i;
-          if (i_out_idx < rows) {
-            /* NOTE: recall correctness & filtering
-
-               In the loop below, we filter the ground truth values on-the-fly.
-               We need enough ground truth values to compute recall correctly though.
-               But the ground truth file only contains `max_k` values per row; if there are less valid
-               values than k among them, we overestimate the recall. Essentially, we compare the first
-               `filter_pass_count` values of the algorithm output, and this counter can be less than `k`.
-               In the extreme case of very high filtering rate, we may be bypassing entire rows of
-               results. However, this is still better than no recall estimate at all.
-
-               TODO: consider generating the filtered ground truth on-the-fly
-            */
-            uint32_t filter_pass_count = 0;
-            for (std::uint32_t l = 0; l < max_k && filter_pass_count < k; l++) {
-              auto exp_idx = gt[i_orig_idx * max_k + l];
-              if (!filter(exp_idx)) { continue; }
-              filter_pass_count++;
-              for (std::uint32_t j = 0; j < k; j++) {
-                auto act_idx = static_cast<std::int32_t>(neighbors_host[i_out_idx * k + j]);
-                if (act_idx == exp_idx) {
-                  match_count++;
-                  break;
-                }
-              }
+        std::vector<std::thread> recall_calculation_workers;
+        recall_calculation_workers.reserve(num_recall_calculation_worker_threads);
+        std::vector<std::size_t> local_match_count(num_recall_calculation_worker_threads + 1);
+        std::vector<std::size_t> local_total_count(num_recall_calculation_worker_threads + 1);
+        int chunk_size =
+          n_queries / (num_recall_calculation_worker_threads + 1);  // +1 for the main thread
+        int remainder = n_queries % (num_recall_calculation_worker_threads + 1);
+        auto recall_calculation = [&](int start, int end, int tid) -> void {
+          for (int i = start; i < end; ++i) {
+            size_t i_orig_idx = batch_offset + i;
+            size_t i_out_idx  = out_offset + i;
+            if (i_out_idx < rows) {
+              auto* candidates = neighbors_host + i_out_idx * k;
+              auto [matching, total] = gt_maps->count_matches(i_orig_idx, candidates, k);
+              local_match_count[tid] += matching;
+              local_total_count[tid] += total;
             }
-            total_count += filter_pass_count;
           }
+        };
+        // launch worker threads
+        int start = 0;
+        for (int tid = 0; tid < num_recall_calculation_worker_threads; tid++) {
+          int end = start + chunk_size;
+          if (tid < remainder) { ++end; }
+          recall_calculation_workers.emplace_back(recall_calculation, start, end, tid);
+          start = end;
         }
+        // main thread works on last chunk
+        recall_calculation(start, n_queries, num_recall_calculation_worker_threads);
+        // join all worker threads
+        for (auto& worker : recall_calculation_workers) {
+          worker.join();
+        }
+        match_count += std::accumulate(local_match_count.begin(), local_match_count.end(), 0);
+        total_count += std::accumulate(local_total_count.begin(), local_total_count.end(), 0);
+
         out_offset += n_queries;
         batch_offset = (batch_offset + queries_stride) % query_set_size;
       }
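The chunk-size/remainder partitioning used above (and again in the map-building code in dataset.hpp) spreads `n_queries` as evenly as possible across workers plus the main thread. A standalone sketch of the scheme, with illustrative names (`run_partitioned` is not part of the PR):

```cpp
#include <cassert>
#include <cstddef>
#include <numeric>
#include <thread>
#include <vector>

// Split n items across (num_workers + 1) participants (the +1 is the main
// thread), giving the first `remainder` participants one extra item, and run
// `fn(start, end, tid)` on each contiguous [start, end) chunk.
template <typename Fn>
void run_partitioned(int n, int num_workers, Fn fn)
{
  int chunk_size = n / (num_workers + 1);
  int remainder  = n % (num_workers + 1);
  std::vector<std::thread> workers;
  workers.reserve(num_workers);
  int start = 0;
  for (int tid = 0; tid < num_workers; ++tid) {
    int end = start + chunk_size + (tid < remainder ? 1 : 0);
    workers.emplace_back(fn, start, end, tid);
    start = end;
  }
  fn(start, n, num_workers);  // main thread takes the last chunk
  for (auto& w : workers) { w.join(); }
}
```

Each thread writes only its own `local_match_count[tid]` / `local_total_count[tid]` slot, so no mutex or atomic is needed; the partial counts are summed with `std::accumulate` only after all workers have joined.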

cpp/bench/ann/src/common/dataset.hpp

Lines changed: 115 additions & 13 deletions
@@ -1,5 +1,5 @@
 /*
- * SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION.
+ * SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION.
  * SPDX-License-Identifier: Apache-2.0
  */
 #pragma once
@@ -14,6 +14,7 @@
 #include <optional>
 #include <random>
 #include <string>
+#include <thread>

 namespace cuvs::bench {

@@ -33,19 +34,121 @@ void generate_bernoulli(CarrierT* data, size_t words, double p)
   }
 };

+template <typename T>
+struct ground_truth_map {
+  using bitset_carrier_type = uint32_t;
+  static constexpr uint32_t kMaxQueriesForRecall = 10'000;
+
+  explicit ground_truth_map(std::string file_name,
+                            uint32_t n_queries,
+                            std::optional<blob<bitset_carrier_type>>& filter_bitset)
+    : gt_maps_(n_queries)
+  {
+    // Eagerly iterate over and optionally filter the ground truth set to build gt_maps_ for up to
+    // kMaxQueriesForRecall queries
+    /* NOTE: recall correctness & filtering
+
+       We generate the filtered ground truth values and build unordered_maps with them to
+       enable O(1) lookup. We need enough ground truth values to compute recall correctly
+       though. But the ground truth file only contains `max_k_` values per row; if there are
+       less valid values than k among them, we overestimate the recall. Essentially, we compare
+       the first `gt_maps_[query_idx].size()` values of the algorithm output, and this value can be
+       less than `k`. In the extreme case of very high filtering rate, we may be bypassing
+       entire rows of results. However, this is still better than no recall estimate at all.
+    */
+    auto ground_truth_set = blob<T>(file_name);
+    max_k_ = ground_truth_set.n_cols();
+    auto filter = [&](T i) -> bool {
+      if (!filter_bitset.has_value()) { return true; }
+      // bitset is `32 = bitset_carrier_type * 8` times more dense than the data
+      // use bitwise arithmetic to get the `row_id` and correct bit pos in the `word`
+      auto word = filter_bitset->data(MemoryType::kHostMmap)[i >> 5];
+      return word & (1 << (i & 31));
+    };
+    // Avoid CPU oversubscription when parallelizing recall calculation loop
+    int num_map_building_worker_threads =
+      std::thread::hardware_concurrency() - 1;  // -1 for the main thread
+    // ensure non-negative number of workers (possible if hardware_concurrency()
+    // does not return an expected value) by clamping to 0
+    if (num_map_building_worker_threads < 0) { num_map_building_worker_threads = 0; }
+    std::vector<std::thread> gt_map_building_workers;
+    gt_map_building_workers.reserve(num_map_building_worker_threads);
+    int chunk_size = n_queries / (num_map_building_worker_threads + 1);
+    int remainder  = n_queries % (num_map_building_worker_threads + 1);
+    int stride     = (n_queries - 1) / kMaxQueriesForRecall + 1;  // round-up division
+    auto build_gt_map = [&](int start, int end, int tid) -> void {
+      for (int query_idx = start; query_idx < end; ++query_idx) {
+        if (query_idx % stride) continue;
+        for (std::uint32_t neighbor_rank = 0; neighbor_rank < max_k_; ++neighbor_rank) {
+          auto id = ground_truth_set.data()[query_idx * max_k_ + neighbor_rank];
+          if (!filter(id)) { continue; }
+          if (gt_maps_[query_idx].count(id)) {
+            throw std::invalid_argument(
+              "Duplicate neighbor id found in ground truth set for query " +
+              std::to_string(query_idx));
+          }
+          gt_maps_[query_idx][id] = neighbor_rank;
+        }
+      }
+    };
+    // launch worker threads
+    int start = 0;
+    for (int tid = 0; tid < num_map_building_worker_threads; tid++) {
+      int end = start + chunk_size;
+      if (tid < remainder) { ++end; }
+      gt_map_building_workers.emplace_back(build_gt_map, start, end, tid);
+      start = end;
+    }
+    // main thread works on last chunk
+    build_gt_map(start, n_queries, num_map_building_worker_threads);
+    // join all worker threads
+    for (auto& worker : gt_map_building_workers) {
+      worker.join();
+    }
+  }
+
+  [[nodiscard]] auto max_k() const -> uint32_t { return max_k_; }
+
+  template <typename index_type>
+  [[nodiscard]] auto count_matches(size_t query_idx, const index_type* candidates, uint32_t k) const
+    -> std::pair<size_t, size_t>
+  {
+    if (query_idx >= gt_maps_.size() || gt_maps_[query_idx].empty()) return {0, 0};
+
+    size_t matching = 0;
+    for (uint32_t i = 0; i < k; ++i) {
+      auto act_idx = candidates[i];
+      if (gt_maps_[query_idx].count(act_idx) &&
+          static_cast<uint32_t>(gt_maps_[query_idx].at(act_idx)) < k) {
+        ++matching;
+      }
+    }
+    size_t total = std::min(gt_maps_[query_idx].size(), static_cast<size_t>(k));
+    return {matching, total};
+  }
+
+ private:
+  // Hash maps of {id, neighbor_rank} for up to kMaxQueriesForRecall queries in the ground truth set
+  // e.g. gt_maps_[i][j] = k means that for the i-th query in the ground truth set, the neighbor
+  // with idx j is the k-th nearest. Note that the nearest neighbor rank starts from 0.
+  std::vector<std::unordered_map<T, T>> gt_maps_;
+  uint32_t max_k_ = 0;  // number of nearest neighbors in the ground truth
+};
+
 template <typename DataT, typename IdxT = int32_t>
 struct dataset {
  public:
-  using bitset_carrier_type = uint32_t;
+  using bitset_carrier_type = typename ground_truth_map<IdxT>::bitset_carrier_type;
   static inline constexpr size_t kBitsPerCarrierValue = sizeof(bitset_carrier_type) * 8;

  private:
   std::string name_;
   std::string distance_;
   blob<DataT> base_set_;
   blob<DataT> query_set_;
-  std::optional<blob<IdxT>> ground_truth_set_;
   std::optional<blob<bitset_carrier_type>> filter_bitset_;
+  std::optional<ground_truth_map<IdxT>> ground_truth_map_;

   // Protects the lazy mutations of the blobs accessed by multiple threads
   mutable std::mutex mutex_;
@@ -73,10 +176,7 @@ struct dataset {
     : name_{std::move(name)},
       distance_{std::move(distance)},
       base_set_{base_file, subset_first_row, subset_size},
-      query_set_{query_file},
-      ground_truth_set_{groundtruth_neighbors_file.has_value()
-                          ? std::make_optional<blob<IdxT>>(groundtruth_neighbors_file.value())
-                          : std::nullopt}
+      query_set_{query_file}
   {
     if (filtering_rate.has_value()) {
       // Generate a random bitset for filtering
@@ -94,6 +194,11 @@ struct dataset {
                          1.0 - filtering_rate.value());
       filter_bitset_.emplace(std::move(bitset_blob));
     }
+
+    if (groundtruth_neighbors_file.has_value()) {
+      ground_truth_map_.emplace(ground_truth_map<IdxT>{
+        groundtruth_neighbors_file.value(), query_set_.n_rows(), filter_bitset_});
+    }
   }

   [[nodiscard]] auto name() const -> std::string { return name_; }
@@ -118,8 +223,7 @@ struct dataset {
   }
   [[nodiscard]] auto max_k() const -> uint32_t
   {
-    std::lock_guard<std::mutex> lock(mutex_);
-    if (ground_truth_set_.has_value()) { return ground_truth_set_->n_cols(); }
+    if (ground_truth_map_.has_value()) { return ground_truth_map_->max_k(); }
     return 0;
   }
   [[nodiscard]] auto base_set_size() const -> size_t
@@ -137,11 +241,9 @@ struct dataset {
     return r;
   }

-  [[nodiscard]] auto gt_set() const -> const IdxT*
+  [[nodiscard]] auto gt_maps() const -> const std::optional<ground_truth_map<IdxT>>&
   {
-    std::lock_guard<std::mutex> lock(mutex_);
-    if (ground_truth_set_.has_value()) { return ground_truth_set_->data(); }
-    return nullptr;
+    return ground_truth_map_;
   }

   [[nodiscard]] auto query_set() const -> const DataT*
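The stride in the diff above, `(n_queries - 1) / kMaxQueriesForRecall + 1`, is round-up division: queries whose index is not a multiple of the stride are skipped, so at most about 10,000 maps are built no matter how large the query set is. A quick sketch of the arithmetic (helper names are illustrative, not part of the PR):

```cpp
#include <cassert>
#include <cstdint>

// Round-up division: the smallest stride such that sampling every stride-th
// query index keeps the number of sampled queries within the cap.
constexpr std::uint32_t recall_stride(std::uint32_t n_queries, std::uint32_t max_queries)
{
  return (n_queries - 1) / max_queries + 1;
}

// Number of indices in [0, n_queries) divisible by the stride, i.e. how many
// queries actually get a ground-truth map built for them.
constexpr std::uint32_t sampled_queries(std::uint32_t n_queries, std::uint32_t max_queries)
{
  return (n_queries - 1) / recall_stride(n_queries, max_queries) + 1;
}
```

For query sets at or below the cap the stride is 1 (every query is kept), so the change only affects very large query sets, consistent with the assumption stated in the commit message that the sampled queries remain representative.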
