Skip to content

Commit 3828aa5

Browse files
remove override range_search from hnsw, use iterator-based instead
Signed-off-by: min.tian <min.tian.cn@gmail.com>
1 parent 8a705a0 commit 3828aa5

1 file changed

Lines changed: 0 additions & 165 deletions

File tree

src/index/hnsw/faiss_hnsw.cc

Lines changed: 0 additions & 165 deletions
Original file line numberDiff line numberDiff line change
@@ -1326,171 +1326,6 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode {
13261326
return res;
13271327
}
13281328

1329-
expected<DataSetPtr>
1330-
RangeSearch(const DataSetPtr dataset, std::unique_ptr<Config> cfg, const BitsetView& bitset) const override {
1331-
if (this->indexes.empty()) {
1332-
return expected<DataSetPtr>::Err(Status::empty_index, "index not loaded");
1333-
}
1334-
for (const auto& index : indexes) {
1335-
if (index == nullptr) {
1336-
return expected<DataSetPtr>::Err(Status::empty_index, "index not loaded");
1337-
}
1338-
if (!index->is_trained) {
1339-
return expected<DataSetPtr>::Err(Status::index_not_trained, "index not trained");
1340-
}
1341-
}
1342-
1343-
const auto dim = dataset->GetDim();
1344-
const auto rows = dataset->GetRows();
1345-
const auto* data = dataset->GetTensor();
1346-
1347-
const auto hnsw_cfg = static_cast<const FaissHnswConfig&>(*cfg);
1348-
auto index_id = getIndexToSearchByScalarInfo(hnsw_cfg, bitset);
1349-
if (index_id < 0) {
1350-
return expected<DataSetPtr>::Err(Status::invalid_args, "partition key value not correctly set");
1351-
}
1352-
1353-
const bool is_similarity_metric = faiss::is_similarity_metric(indexes[index_id]->metric_type);
1354-
1355-
const float radius = hnsw_cfg.radius.value();
1356-
const float range_filter = hnsw_cfg.range_filter.value();
1357-
1358-
feder::hnsw::FederResultUniq feder_result;
1359-
if (hnsw_cfg.trace_visit.value()) {
1360-
if (rows != 1) {
1361-
return expected<DataSetPtr>::Err(Status::invalid_args, "a single query vector is required");
1362-
}
1363-
feder_result = std::make_unique<feder::hnsw::FederResult>();
1364-
}
1365-
1366-
// check for brute-force search
1367-
auto whether_bf_search = WhetherPerformBruteForceRangeSearch(indexes[index_id].get(), hnsw_cfg, bitset);
1368-
1369-
if (!whether_bf_search.has_value()) {
1370-
return expected<DataSetPtr>::Err(Status::invalid_args, "ef parameter is missing");
1371-
}
1372-
1373-
// whether a user wants a refine
1374-
const bool whether_to_enable_refine = true;
1375-
1376-
// set up an index wrapper
1377-
auto [index_wrapper, is_refined] = create_conditional_hnsw_wrapper(
1378-
indexes[index_id].get(), hnsw_cfg, whether_bf_search.value_or(false), whether_to_enable_refine);
1379-
1380-
if (index_wrapper == nullptr) {
1381-
return expected<DataSetPtr>::Err(Status::invalid_args, "an input index seems to be unrelated to HNSW");
1382-
}
1383-
1384-
faiss::Index* index_wrapper_ptr = index_wrapper.get();
1385-
1386-
// set up faiss search parameters
1387-
knowhere::SearchParametersHNSWWrapper hnsw_search_params;
1388-
1389-
if (hnsw_cfg.ef.has_value()) {
1390-
hnsw_search_params.efSearch = hnsw_cfg.ef.value();
1391-
}
1392-
1393-
// do not collect HNSW stats
1394-
hnsw_search_params.hnsw_stats = nullptr;
1395-
// set up feder
1396-
hnsw_search_params.feder = feder_result.get();
1397-
// set up kAlpha
1398-
hnsw_search_params.kAlpha = bitset.filter_ratio() * 0.7f;
1399-
1400-
// set up a selector
1401-
BitsetViewIDSelector bw_idselector(bitset);
1402-
BitsetViewWithMappingIDSelector bw_mapping_idselector(
1403-
bitset, labels.empty() ? nullptr : labels[index_id].get()->data());
1404-
faiss::IDSelector* id_selector = nullptr;
1405-
if (!bitset.empty()) {
1406-
if (labels.empty()) {
1407-
id_selector = &bw_idselector;
1408-
} else {
1409-
id_selector = &bw_mapping_idselector;
1410-
}
1411-
}
1412-
hnsw_search_params.sel = id_selector;
1413-
1414-
////////////////////////////////////////////////////////////////
1415-
// run
1416-
std::vector<std::vector<int64_t>> result_id_array(rows);
1417-
std::vector<std::vector<float>> result_dist_array(rows);
1418-
1419-
std::vector<folly::Future<folly::Unit>> futs;
1420-
futs.reserve(rows);
1421-
1422-
// a sequential version
1423-
for (int64_t i = 0; i < rows; ++i) {
1424-
// const int64_t idx = i;
1425-
// {
1426-
1427-
futs.emplace_back(
1428-
search_pool->push([&, idx = i, is_refined = is_refined, index_wrapper_ptr = index_wrapper_ptr] {
1429-
// 1 thread per element
1430-
ThreadPool::ScopedSearchOmpSetter setter(1);
1431-
1432-
// set up a query
1433-
const float* cur_query = nullptr;
1434-
1435-
std::vector<float> cur_query_tmp(dim);
1436-
if (data_format == DataFormatEnum::fp32) {
1437-
cur_query = (const float*)data + idx * dim;
1438-
} else {
1439-
convert_rows_to_fp32(data, cur_query_tmp.data(), data_format, idx, 1, dim);
1440-
cur_query = cur_query_tmp.data();
1441-
}
1442-
1443-
// initialize a buffer
1444-
faiss::RangeSearchResult res(1);
1445-
1446-
// perform the search
1447-
if (is_refined) {
1448-
faiss::IndexRefineSearchParameters refine_params;
1449-
refine_params.k_factor = hnsw_cfg.refine_k.value_or(1);
1450-
// a refine procedure itself does not need to care about filtering
1451-
refine_params.sel = nullptr;
1452-
refine_params.base_index_params = &hnsw_search_params;
1453-
1454-
index_wrapper_ptr->range_search(1, cur_query, radius, &res, &refine_params);
1455-
} else {
1456-
index_wrapper_ptr->range_search(1, cur_query, radius, &res, &hnsw_search_params);
1457-
}
1458-
1459-
// post-process
1460-
const size_t elem_cnt = res.lims[1];
1461-
result_dist_array[idx].resize(elem_cnt);
1462-
result_id_array[idx].resize(elem_cnt);
1463-
1464-
if (labels.empty()) {
1465-
for (size_t j = 0; j < elem_cnt; j++) {
1466-
result_dist_array[idx][j] = res.distances[j];
1467-
result_id_array[idx][j] = res.labels[j];
1468-
}
1469-
} else {
1470-
for (size_t j = 0; j < elem_cnt; j++) {
1471-
result_dist_array[idx][j] = res.distances[j];
1472-
result_id_array[idx][j] =
1473-
res.labels[j] < 0 ? res.labels[j] : labels[index_id]->operator[](res.labels[j]);
1474-
}
1475-
}
1476-
1477-
if (hnsw_cfg.range_filter.value() != defaultRangeFilter) {
1478-
FilterRangeSearchResultForOneNq(result_dist_array[idx], result_id_array[idx],
1479-
is_similarity_metric, radius, range_filter);
1480-
}
1481-
}));
1482-
}
1483-
1484-
// wait for the completion
1485-
WaitAllSuccess(futs);
1486-
1487-
//
1488-
RangeSearchResult range_search_result =
1489-
GetRangeSearchResult(result_dist_array, result_id_array, is_similarity_metric, rows, radius, range_filter);
1490-
1491-
return GenResultDataSet(rows, std::move(range_search_result));
1492-
}
1493-
14941329
protected:
14951330
DataFormatEnum data_format;
14961331

0 commit comments

Comments
 (0)