11/*
2- * SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION.
2+ * SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION.
33 * SPDX-License-Identifier: Apache-2.0
44 */
55#pragma once
@@ -351,15 +351,9 @@ void bench_search(::benchmark::State& state,
351351
352352 // Each thread calculates recall on their partition of queries.
353353 // evaluate recall
354- if (dataset->max_k () >= k) {
355- const std::int32_t * gt = dataset->gt_set ();
356- const std::uint32_t * filter_bitset = dataset->filter_bitset (MemoryType::kHostMmap );
357- auto filter = [filter_bitset](std::int32_t i) -> bool {
358- if (filter_bitset == nullptr ) { return true ; }
359- auto word = filter_bitset[i >> 5 ];
360- return word & (1 << (i & 31 ));
361- };
362- const std::uint32_t max_k = dataset->max_k ();
354+ if (dataset->max_k () >= k && dataset->gt_maps ().has_value ()) {
355+ // gt_maps[i] is a hash map of {id, neighbor_rank} for query i
356+ const auto & gt_maps = dataset->gt_maps ();
363357 result_buf.transfer_data (MemoryType::kHost , current_algo_props->query_memory_type );
364358 auto * neighbors_host = reinterpret_cast <index_type*>(result_buf.data (MemoryType::kHost ));
365359 std::size_t rows = std::min (queries_processed, query_set_size);
@@ -369,39 +363,49 @@ void bench_search(::benchmark::State& state,
369363 // We go through the groundtruth with same stride as the benchmark loop.
370364 size_t out_offset = 0 ;
371365 size_t batch_offset = (state.thread_index () * n_queries) % query_set_size;
366+ // Avoid CPU oversubscription when parallelizing recall calculation loop
367+ int num_recall_calculation_worker_threads =
368+ std::thread::hardware_concurrency () / benchmark_n_threads - 1 ; // -1 for the main thread
369+ // ensure non-negative number of workers (possible if hardware_concurrency()
370+ // does not return an expected value) by clamping to 0
371+ if (num_recall_calculation_worker_threads < 0 ) { num_recall_calculation_worker_threads = 0 ; }
372372 while (out_offset < rows) {
373- for (std::size_t i = 0 ; i < n_queries; i++) {
374- size_t i_orig_idx = batch_offset + i;
375- size_t i_out_idx = out_offset + i;
376- if (i_out_idx < rows) {
377- /* NOTE: recall correctness & filtering
378-
379- In the loop below, we filter the ground truth values on-the-fly.
380- We need enough ground truth values to compute recall correctly though.
381- But the ground truth file only contains `max_k` values per row; if there are less valid
382- values than k among them, we overestimate the recall. Essentially, we compare the first
383- `filter_pass_count` values of the algorithm output, and this counter can be less than `k`.
384- In the extreme case of very high filtering rate, we may be bypassing entire rows of
385- results. However, this is still better than no recall estimate at all.
386-
387- TODO: consider generating the filtered ground truth on-the-fly
388- */
389- uint32_t filter_pass_count = 0 ;
390- for (std::uint32_t l = 0 ; l < max_k && filter_pass_count < k; l++) {
391- auto exp_idx = gt[i_orig_idx * max_k + l];
392- if (!filter (exp_idx)) { continue ; }
393- filter_pass_count++;
394- for (std::uint32_t j = 0 ; j < k; j++) {
395- auto act_idx = static_cast <std::int32_t >(neighbors_host[i_out_idx * k + j]);
396- if (act_idx == exp_idx) {
397- match_count++;
398- break ;
399- }
400- }
373+ std::vector<std::thread> recall_calculation_workers;
374+ recall_calculation_workers.reserve (num_recall_calculation_worker_threads);
375+ std::vector<std::size_t > local_match_count (num_recall_calculation_worker_threads + 1 );
376+ std::vector<std::size_t > local_total_count (num_recall_calculation_worker_threads + 1 );
377+ int chunk_size =
378+ n_queries / (num_recall_calculation_worker_threads + 1 ); // +1 for the main thread
379+ int remainder = n_queries % (num_recall_calculation_worker_threads + 1 );
380+ auto recall_calculation = [&](int start, int end, int tid) -> void {
381+ for (int i = start; i < end; ++i) {
382+ size_t i_orig_idx = batch_offset + i;
383+ size_t i_out_idx = out_offset + i;
384+ if (i_out_idx < rows) {
385+ auto * candidates = neighbors_host + i_out_idx * k;
386+ auto [matching, total] = gt_maps->count_matches (i_orig_idx, candidates, k);
387+ local_match_count[tid] += matching;
388+ local_total_count[tid] += total;
401389 }
402- total_count += filter_pass_count;
403390 }
391+ };
392+ // launch worker threads
393+ int start = 0 ;
394+ for (int tid = 0 ; tid < num_recall_calculation_worker_threads; tid++) {
395+ int end = start + chunk_size;
396+ if (tid < remainder) { ++end; }
397+ recall_calculation_workers.emplace_back (recall_calculation, start, end, tid);
398+ start = end;
404399 }
400+ // main thread works on last chunk
401+ recall_calculation (start, n_queries, num_recall_calculation_worker_threads);
402+ // join all worker threads
403+ for (auto & worker : recall_calculation_workers) {
404+ worker.join ();
405+ }
406+ match_count += std::accumulate (local_match_count.begin (), local_match_count.end (), 0 );
407+ total_count += std::accumulate (local_total_count.begin (), local_total_count.end (), 0 );
408+
405409 out_offset += n_queries;
406410 batch_offset = (batch_offset + queries_stride) % query_set_size;
407411 }
0 commit comments