-
Notifications
You must be signed in to change notification settings - Fork 184
Expand file tree
/
Copy pathcagra.cpp
More file actions
59 lines (53 loc) · 2.24 KB
/
cagra.cpp
File metadata and controls
59 lines (53 loc) · 2.24 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
/*
* SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/
#include <cuvs/distance/distance.hpp>
#include <cuvs/neighbors/cagra.hpp>
#include <cuvs/neighbors/common.hpp>

#include <cmath>
#include <cstdint>
namespace cuvs::neighbors::cagra {
/**
 * Choose the kNN-graph construction algorithm (and its tuning knobs) used by a CAGRA build
 * that was parameterized from HNSW-style settings.
 *
 * Datasets with fewer than 1'000'000 rows are built with NN-descent; larger datasets are built
 * with IVF-PQ. In both branches the amount of search work grows with `ef_construction`.
 *
 * @param dataset                   extents of the dataset; extent(0) is used as the row count
 * @param intermediate_graph_degree degree passed to the NN-descent builder
 * @param ef_construction           HNSW-style ef_construction value
 * @param metric                    distance metric forwarded to whichever builder is chosen
 * @return graph build parameters; the active alternative depends on the dataset size
 */
inline auto graph_params_heuristic(raft::matrix_extent<int64_t> dataset,
                                   int intermediate_graph_degree,
                                   int ef_construction,
                                   cuvs::distance::DistanceType metric)
  -> decltype(index_params::graph_build_params)
{
  // Row count below which NN-descent is preferred over IVF-PQ (same value as int64_t(1e6)).
  constexpr int64_t kNnDescentMaxRows = 1'000'000;

  if (dataset.extent(0) < kNnDescentMaxRows) {
    auto nn_descent_params =
      graph_build_params::nn_descent_params(intermediate_graph_degree, metric);
    // More refinement iterations for larger ef_construction (integer division by 16).
    nn_descent_params.max_iterations = 5 + ef_construction / 16;
    return nn_descent_params;
  }

  auto ivf_pq_params = cuvs::neighbors::graph_build_params::ivf_pq_params(dataset, metric);
  // Probe count grows with sqrt(n_lists) and with ef_construction. The sum is evaluated in
  // double (ef_construction / 16 is still an integer division) and rounded, then converted
  // explicitly to the field's integer type instead of relying on an implicit narrowing.
  using n_probes_t    = decltype(ivf_pq_params.search_params.n_probes);
  const double n_lists = static_cast<double>(ivf_pq_params.build_params.n_lists);
  const double probes  = std::round(2.0 + std::sqrt(n_lists) / 20.0 +
                                   static_cast<double>(ef_construction / 16));
  ivf_pq_params.search_params.n_probes = static_cast<n_probes_t>(probes);
  return ivf_pq_params;
}
/**
 * Derive CAGRA index build parameters from HNSW-style parameters (M, ef_construction).
 *
 * @param dataset         extents of the dataset to be indexed (forwarded to the graph-build heuristic)
 * @param M               HNSW-style M value used to size the CAGRA graph degrees
 * @param ef_construction HNSW-style ef_construction value
 * @param heuristic       mapping strategy from M to CAGRA degrees; unknown values fall back to
 *                        SIMILAR_SEARCH_PERFORMANCE
 * @param metric          distance metric; stored in the returned params and used for the
 *                        kNN-graph builder
 * @return fully populated cagra::index_params
 */
cagra::index_params index_params::from_hnsw_params(raft::matrix_extent<int64_t> dataset,
                                                   int M,
                                                   int ef_construction,
                                                   hnsw_heuristic_type heuristic,
                                                   cuvs::distance::DistanceType metric)
{
  cagra::index_params params;
  // Record the requested metric on the index itself so it matches the metric used to build the
  // kNN graph below; previously the returned params silently kept the default metric.
  params.metric = metric;

  switch (heuristic) {
    case hnsw_heuristic_type::SAME_GRAPH_FOOTPRINT:
      // graph_degree = 2*M; presumably chosen to match the footprint of an HNSW base layer.
      params.graph_degree                   = M * 2;
      params.intermediate_graph_degree      = M * 3;
      params.variable_graph_degree_fraction = 0.35;
      break;
    case hnsw_heuristic_type::SIMILAR_SEARCH_PERFORMANCE:
    default:
      // Intermediate degree grows with ef_construction (integer arithmetic).
      params.graph_degree                   = M;
      params.intermediate_graph_degree      = M + M * ef_construction / 256;
      params.variable_graph_degree_fraction = 0.7;
      break;
  }

  params.graph_build_params =
    graph_params_heuristic(dataset, params.intermediate_graph_degree, ef_construction, metric);
  return params;
}
} // namespace cuvs::neighbors::cagra