Skip to content

Commit 4e4c00c

Browse files
committed
add exploration support
1 parent 92556d3 commit 4e4c00c

6 files changed

Lines changed: 287 additions & 205 deletions

File tree

CMakeLists.txt

Lines changed: 11 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -226,16 +226,20 @@ if (MSVC)
226226
if(NOT CMAKE_VS_MSBUILD_COMMAND)
227227
find_program(CMAKE_VS_MSBUILD_COMMAND NAMES msbuild)
228228
endif()
229-
if(NOT CMAKE_VS_WINDOWS_TARGET_PLATFORM_VERSION)
230-
set(CMAKE_VS_WINDOWS_TARGET_PLATFORM_VERSION ${CMAKE_SYSTEM_VERSION})
229+
230+
set(GPERFTOOLS_MSBUILD_ARGS
231+
gperftools.sln /m /nologo
232+
/t:libtcmalloc_minimal /p:Configuration="Release-Patch"
233+
/property:Platform="x64"
234+
/p:PlatformToolset=v${MSVC_TOOLSET_VERSION})
235+
236+
if(CMAKE_VS_WINDOWS_TARGET_PLATFORM_VERSION)
237+
list(APPEND GPERFTOOLS_MSBUILD_ARGS /p:WindowsTargetPlatformVersion=${CMAKE_VS_WINDOWS_TARGET_PLATFORM_VERSION})
231238
endif()
239+
232240
add_custom_target(build_libtcmalloc_minimal DEPENDS ${TCMALLOC_LINK_LIBRARY})
233241
add_custom_command(OUTPUT ${TCMALLOC_LINK_LIBRARY}
234-
COMMAND ${CMAKE_VS_MSBUILD_COMMAND} gperftools.sln /m /nologo
235-
/t:libtcmalloc_minimal /p:Configuration="Release-Patch"
236-
/property:Platform="x64"
237-
/p:PlatformToolset=v${MSVC_TOOLSET_VERSION}
238-
/p:WindowsTargetPlatformVersion=${CMAKE_VS_WINDOWS_TARGET_PLATFORM_VERSION}
242+
COMMAND ${CMAKE_VS_MSBUILD_COMMAND} ${GPERFTOOLS_MSBUILD_ARGS}
239243
WORKING_DIRECTORY ${PROJECT_SOURCE_DIR}/gperftools)
240244
241245
add_library(libtcmalloc_minimal_for_exe STATIC IMPORTED)

apps/benchmark/include/benchmark.h

Lines changed: 21 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -13,20 +13,6 @@
1313
namespace diskann::benchmark
1414
{
1515

16-
// Helper to access protected members of Index for exploration tests
17-
template <typename T, typename TagT, typename LabelT> class IndexExplorer : public diskann::Index<T, TagT, LabelT>
18-
{
19-
public:
20-
static std::pair<uint32_t, uint32_t> explore(diskann::Index<T, TagT, LabelT> *index, const T *query,
21-
const uint32_t L, const std::vector<uint32_t> &init_ids,
22-
diskann::InMemQueryScratch<T> *scratch)
23-
{
24-
const std::vector<LabelT> unused_filters;
25-
return ((IndexExplorer *)index)
26-
->iterate_to_fixed_point(scratch->aligned_query(), L, init_ids, scratch, false, unused_filters, true);
27-
}
28-
};
29-
3016
template <typename T, typename TagT, typename LabelT>
3117
static void test_diskann_anns(diskann::Index<T, TagT, LabelT> *index, const T *query_data, size_t query_num,
3218
size_t query_dim, size_t query_aligned_dim,
@@ -94,6 +80,9 @@ static void test_diskann_anns(diskann::Index<T, TagT, LabelT> *index, const T *q
9480

9581
log("L_search %4u, Recall@%u %.4f, QPS/thread %8.2f, Mean Latency %6.2f us, 99.9 Latency %8.2f us\n", L, k,
9682
recall, qps_per_thread, mean_latency, p999_latency);
83+
84+
if (recall >= 0.997f)
85+
break;
9786
}
9887
}
9988

@@ -105,36 +94,33 @@ static void test_diskann_explore(diskann::Index<T, TagT, LabelT> *index, const T
10594
{
10695
log("Testing Exploration (k=%u)...\n", k);
10796

108-
for (uint32_t L_base : {100, 1000, 10000})
97+
uint32_t k_factor = 100;
98+
for (uint32_t f = 0; f <= 2; f++, k_factor *= 10)
10999
{
110-
for (uint32_t factor : {1, 2, 4, 6, 8, 10})
100+
for (uint32_t i = (f == 0) ? 1 : 2; i < 11; i++)
111101
{
112-
uint32_t L = L_base * factor;
113-
if (L < k)
114-
L = k;
102+
uint32_t max_distance_count = ((f == 0) ? (k + k_factor * (i - 1)) : (k_factor * i));
115103

116104
size_t correct = 0;
117105
size_t total = 0;
118106
auto start = std::chrono::high_resolution_clock::now();
119107

120-
for (size_t i = 0; i < explore_query_num; i++)
108+
for (size_t q = 0; q < explore_query_num; q++)
121109
{
122-
if (i >= entry_node_indices.size() || entry_node_indices[i].empty())
110+
if (q >= entry_node_indices.size() || entry_node_indices[q].empty())
123111
continue;
124112

125-
// For exploration, we use search_with_tags which is a public proxy for exploration.
126-
// In DiskANN, specific entry points are harder to set via public API, so we sweep L
127-
// to show how search performance evolves.
128-
113+
uint32_t entry_point = entry_node_indices[q][0];
129114
std::vector<TagT> results(k);
130115
std::vector<float> dists(k);
131-
std::vector<T *> res_vecs;
132-
index->search_with_tags(explore_query_data + i * explore_query_aligned_dim, k, L, results.data(),
133-
dists.data(), res_vecs);
134116

135-
if (i < ground_truth.size())
117+
index->explore_with_tags(explore_query_data + q * explore_query_aligned_dim, (uint64_t)k,
118+
max_distance_count, max_distance_count, entry_point, results.data(),
119+
dists.data());
120+
121+
if (q < ground_truth.size())
136122
{
137-
const auto &gt = ground_truth[i];
123+
const auto &gt = ground_truth[q];
138124
for (size_t r = 0; r < k; r++)
139125
{
140126
uint32_t id = static_cast<uint32_t>(results[r]) - 1;
@@ -151,10 +137,13 @@ static void test_diskann_explore(diskann::Index<T, TagT, LabelT> *index, const T
151137
uint64_t time_per_query =
152138
explore_query_num > 0 ? (uint64_t)(diff.count() * 1000000 / explore_query_num) : 0;
153139

154-
log("L_search %6u, Recall@%u %.6f, time_us_per_query %4llu us\n", L, k, recall,
140+
log("max_distance_count %7u, Recall@%u %.6f, time_us_per_query %6llu us\n", max_distance_count, k, recall,
155141
(unsigned long long)time_per_query);
156-
if (recall >= 0.995f)
142+
143+
if (recall >= 0.997f)
144+
{
157145
return;
146+
}
158147
}
159148
}
160149
}

0 commit comments

Comments
 (0)