@@ -1326,171 +1326,6 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode {
13261326 return res;
13271327 }
13281328
1329- expected<DataSetPtr>
1330- RangeSearch (const DataSetPtr dataset, std::unique_ptr<Config> cfg, const BitsetView& bitset) const override {
1331- if (this ->indexes .empty ()) {
1332- return expected<DataSetPtr>::Err (Status::empty_index, " index not loaded" );
1333- }
1334- for (const auto & index : indexes) {
1335- if (index == nullptr ) {
1336- return expected<DataSetPtr>::Err (Status::empty_index, " index not loaded" );
1337- }
1338- if (!index->is_trained ) {
1339- return expected<DataSetPtr>::Err (Status::index_not_trained, " index not trained" );
1340- }
1341- }
1342-
1343- const auto dim = dataset->GetDim ();
1344- const auto rows = dataset->GetRows ();
1345- const auto * data = dataset->GetTensor ();
1346-
1347- const auto hnsw_cfg = static_cast <const FaissHnswConfig&>(*cfg);
1348- auto index_id = getIndexToSearchByScalarInfo (hnsw_cfg, bitset);
1349- if (index_id < 0 ) {
1350- return expected<DataSetPtr>::Err (Status::invalid_args, " partition key value not correctly set" );
1351- }
1352-
1353- const bool is_similarity_metric = faiss::is_similarity_metric (indexes[index_id]->metric_type );
1354-
1355- const float radius = hnsw_cfg.radius .value ();
1356- const float range_filter = hnsw_cfg.range_filter .value ();
1357-
1358- feder::hnsw::FederResultUniq feder_result;
1359- if (hnsw_cfg.trace_visit .value ()) {
1360- if (rows != 1 ) {
1361- return expected<DataSetPtr>::Err (Status::invalid_args, " a single query vector is required" );
1362- }
1363- feder_result = std::make_unique<feder::hnsw::FederResult>();
1364- }
1365-
1366- // check for brute-force search
1367- auto whether_bf_search = WhetherPerformBruteForceRangeSearch (indexes[index_id].get (), hnsw_cfg, bitset);
1368-
1369- if (!whether_bf_search.has_value ()) {
1370- return expected<DataSetPtr>::Err (Status::invalid_args, " ef parameter is missing" );
1371- }
1372-
1373- // whether a user wants a refine
1374- const bool whether_to_enable_refine = true ;
1375-
1376- // set up an index wrapper
1377- auto [index_wrapper, is_refined] = create_conditional_hnsw_wrapper (
1378- indexes[index_id].get (), hnsw_cfg, whether_bf_search.value_or (false ), whether_to_enable_refine);
1379-
1380- if (index_wrapper == nullptr ) {
1381- return expected<DataSetPtr>::Err (Status::invalid_args, " an input index seems to be unrelated to HNSW" );
1382- }
1383-
1384- faiss::Index* index_wrapper_ptr = index_wrapper.get ();
1385-
1386- // set up faiss search parameters
1387- knowhere::SearchParametersHNSWWrapper hnsw_search_params;
1388-
1389- if (hnsw_cfg.ef .has_value ()) {
1390- hnsw_search_params.efSearch = hnsw_cfg.ef .value ();
1391- }
1392-
1393- // do not collect HNSW stats
1394- hnsw_search_params.hnsw_stats = nullptr ;
1395- // set up feder
1396- hnsw_search_params.feder = feder_result.get ();
1397- // set up kAlpha
1398- hnsw_search_params.kAlpha = bitset.filter_ratio () * 0 .7f ;
1399-
1400- // set up a selector
1401- BitsetViewIDSelector bw_idselector (bitset);
1402- BitsetViewWithMappingIDSelector bw_mapping_idselector (
1403- bitset, labels.empty () ? nullptr : labels[index_id].get ()->data ());
1404- faiss::IDSelector* id_selector = nullptr ;
1405- if (!bitset.empty ()) {
1406- if (labels.empty ()) {
1407- id_selector = &bw_idselector;
1408- } else {
1409- id_selector = &bw_mapping_idselector;
1410- }
1411- }
1412- hnsw_search_params.sel = id_selector;
1413-
1414- // //////////////////////////////////////////////////////////////
1415- // run
1416- std::vector<std::vector<int64_t >> result_id_array (rows);
1417- std::vector<std::vector<float >> result_dist_array (rows);
1418-
1419- std::vector<folly::Future<folly::Unit>> futs;
1420- futs.reserve (rows);
1421-
1422- // a sequential version
1423- for (int64_t i = 0 ; i < rows; ++i) {
1424- // const int64_t idx = i;
1425- // {
1426-
1427- futs.emplace_back (
1428- search_pool->push ([&, idx = i, is_refined = is_refined, index_wrapper_ptr = index_wrapper_ptr] {
1429- // 1 thread per element
1430- ThreadPool::ScopedSearchOmpSetter setter (1 );
1431-
1432- // set up a query
1433- const float * cur_query = nullptr ;
1434-
1435- std::vector<float > cur_query_tmp (dim);
1436- if (data_format == DataFormatEnum::fp32) {
1437- cur_query = (const float *)data + idx * dim;
1438- } else {
1439- convert_rows_to_fp32 (data, cur_query_tmp.data (), data_format, idx, 1 , dim);
1440- cur_query = cur_query_tmp.data ();
1441- }
1442-
1443- // initialize a buffer
1444- faiss::RangeSearchResult res (1 );
1445-
1446- // perform the search
1447- if (is_refined) {
1448- faiss::IndexRefineSearchParameters refine_params;
1449- refine_params.k_factor = hnsw_cfg.refine_k .value_or (1 );
1450- // a refine procedure itself does not need to care about filtering
1451- refine_params.sel = nullptr ;
1452- refine_params.base_index_params = &hnsw_search_params;
1453-
1454- index_wrapper_ptr->range_search (1 , cur_query, radius, &res, &refine_params);
1455- } else {
1456- index_wrapper_ptr->range_search (1 , cur_query, radius, &res, &hnsw_search_params);
1457- }
1458-
1459- // post-process
1460- const size_t elem_cnt = res.lims [1 ];
1461- result_dist_array[idx].resize (elem_cnt);
1462- result_id_array[idx].resize (elem_cnt);
1463-
1464- if (labels.empty ()) {
1465- for (size_t j = 0 ; j < elem_cnt; j++) {
1466- result_dist_array[idx][j] = res.distances [j];
1467- result_id_array[idx][j] = res.labels [j];
1468- }
1469- } else {
1470- for (size_t j = 0 ; j < elem_cnt; j++) {
1471- result_dist_array[idx][j] = res.distances [j];
1472- result_id_array[idx][j] =
1473- res.labels [j] < 0 ? res.labels [j] : labels[index_id]->operator [](res.labels [j]);
1474- }
1475- }
1476-
1477- if (hnsw_cfg.range_filter .value () != defaultRangeFilter) {
1478- FilterRangeSearchResultForOneNq (result_dist_array[idx], result_id_array[idx],
1479- is_similarity_metric, radius, range_filter);
1480- }
1481- }));
1482- }
1483-
1484- // wait for the completion
1485- WaitAllSuccess (futs);
1486-
1487- //
1488- RangeSearchResult range_search_result =
1489- GetRangeSearchResult (result_dist_array, result_id_array, is_similarity_metric, rows, radius, range_filter);
1490-
1491- return GenResultDataSet (rows, std::move (range_search_result));
1492- }
1493-
14941329 protected:
14951330 DataFormatEnum data_format;
14961331
0 commit comments