|
40 | 40 | #include "index/hnsw/impl/IndexConditionalWrapper.h" |
41 | 41 | #include "index/hnsw/impl/IndexHNSWWrapper.h" |
42 | 42 | #include "index/hnsw/impl/IndexWrapperCosine.h" |
| 43 | +#include "index/refine/refine_utils.h" |
43 | 44 | #include "io/memory_io.h" |
44 | 45 | #include "knowhere/bitsetview_idselector.h" |
45 | 46 | #include "knowhere/comp/index_param.h" |
@@ -2034,205 +2035,6 @@ class BaseFaissRegularIndexHNSWFlatNodeTemplateWithSearchFallback : public HNSWI |
2034 | 2035 | } |
2035 | 2036 | }; |
2036 | 2037 |
|
2037 | | -namespace { |
2038 | | - |
2039 | | -// maps a scalar quantizer name (case-insensitive) to a faiss::ScalarQuantizer::QuantizerType |
2040 | | -expected<faiss::ScalarQuantizer::QuantizerType> |
2041 | | -get_sq_quantizer_type(const std::string& sq_type) { |
2042 | | - std::map<std::string, faiss::ScalarQuantizer::QuantizerType> sq_types = { |
2043 | | - {"sq6", faiss::ScalarQuantizer::QT_6bit}, |
2044 | | - {"sq8", faiss::ScalarQuantizer::QT_8bit}, |
2045 | | - {"fp16", faiss::ScalarQuantizer::QT_fp16}, |
2046 | | - {"bf16", faiss::ScalarQuantizer::QT_bf16}, |
2047 | | - {"int8", faiss::ScalarQuantizer::QT_8bit_direct_signed}}; |
2048 | | - |
2049 | | - // normalize to lowercase, so that the lookup is case-insensitive |
2050 | | - auto sq_type_tolower = str_to_lower(sq_type); |
2051 | | - auto itr = sq_types.find(sq_type_tolower); |
2052 | | - if (itr == sq_types.cend()) { |
2053 | | - return expected<faiss::ScalarQuantizer::QuantizerType>::Err( |
2054 | | - Status::invalid_args, fmt::format("invalid scalar quantizer type ({})", sq_type_tolower)); |
2055 | | - } |
2056 | | - |
2057 | | - return itr->second; |
2058 | | -} |
2059 | | - |
2060 | | -/* |
2061 | | -// checks whether an index contains a refiner, suitable for a given data format |
2062 | | -std::optional<bool> whether_refine_is_datatype( |
2063 | | - const faiss::Index* index, |
2064 | | - const DataFormatEnum data_format |
2065 | | -) { |
2066 | | - if (index == nullptr) { |
2067 | | - return {}; |
2068 | | - } |
2069 | | -
|
2070 | | - const faiss::IndexRefine* const index_refine = dynamic_cast<const faiss::IndexRefine*>(index); |
2071 | | - if (index_refine == nullptr) { |
2072 | | - return false; |
2073 | | - } |
2074 | | -
|
2075 | | - switch(data_format) { |
2076 | | - case DataFormatEnum::fp32: |
2077 | | - return (dynamic_cast<const faiss::IndexFlat*>(index_refine->refine_index) != nullptr); |
2078 | | - case DataFormatEnum::fp16: |
2079 | | - { |
2080 | | - const auto* const index_sq = dynamic_cast<const |
2081 | | -faiss::IndexScalarQuantizer*>(index_refine->refine_index); return (index_sq != nullptr && index_sq->sq.qtype == |
2082 | | -faiss::ScalarQuantizer::QT_fp16); |
2083 | | - } |
2084 | | - case DataFormatEnum::bf16: |
2085 | | - { |
2086 | | - const auto* const index_sq = dynamic_cast<const |
2087 | | -faiss::IndexScalarQuantizer*>(index_refine->refine_index); return (index_sq != nullptr && index_sq->sq.qtype == |
2088 | | -faiss::ScalarQuantizer::QT_bf16); |
2089 | | - } |
2090 | | - default: |
2091 | | - return {}; |
2092 | | - } |
2093 | | -} |
2094 | | -*/ |
2095 | | - |
2096 | | -expected<bool> |
2097 | | -is_flat_refine(const std::optional<std::string>& refine_type) { |
2098 | | - // grab a type of a refine index |
2099 | | - if (!refine_type.has_value()) { |
2100 | | - return true; |
2101 | | - }; |
2102 | | - |
2103 | | - // normalize to lowercase, so that the comparison is case-insensitive |
2104 | | - std::string refine_type_tolower = str_to_lower(refine_type.value()); |
2105 | | - if (refine_type_tolower == "fp32" || refine_type_tolower == "flat") { |
2106 | | - return true; |
2107 | | - }; |
2108 | | - |
2109 | | - // parse |
2110 | | - auto refine_sq_type = get_sq_quantizer_type(refine_type_tolower); |
2111 | | - if (!refine_sq_type.has_value()) { |
2112 | | - LOG_KNOWHERE_ERROR_ << "Invalid refine type: " << refine_type.value(); |
2113 | | - return expected<bool>::Err(Status::invalid_args, fmt::format("invalid refine type ({})", refine_type.value())); |
2114 | | - } |
2115 | | - |
2116 | | - return false; |
2117 | | -} |
2118 | | - |
2119 | | -bool |
2120 | | -has_lossless_quant(const expected<faiss::ScalarQuantizer::QuantizerType>& quant_type, DataFormatEnum dataFormat) { |
2121 | | - if (!quant_type.has_value()) { |
2122 | | - return false; |
2123 | | - } |
2124 | | - |
2125 | | - auto quant = quant_type.value(); |
2126 | | - switch (dataFormat) { |
2127 | | - case DataFormatEnum::fp32: |
2128 | | - return false; |
2129 | | - case DataFormatEnum::fp16: |
2130 | | - return quant == faiss::ScalarQuantizer::QuantizerType::QT_fp16; |
2131 | | - case DataFormatEnum::bf16: |
2132 | | - return quant == faiss::ScalarQuantizer::QuantizerType::QT_bf16; |
2133 | | - case DataFormatEnum::int8: |
2134 | | - return quant == faiss::ScalarQuantizer::QuantizerType::QT_8bit_direct_signed; |
2135 | | - default: |
2136 | | - return false; |
2137 | | - } |
2138 | | -} |
2139 | | - |
2140 | | -bool |
2141 | | -has_lossless_refine_index(const FaissHnswConfig& hnsw_cfg, DataFormatEnum dataFormat) { |
2142 | | - bool has_refine = hnsw_cfg.refine.value_or(false) && hnsw_cfg.refine_type.has_value(); |
2143 | | - if (has_refine) { |
2144 | | - expected<bool> flat_refine = is_flat_refine(hnsw_cfg.refine_type); |
2145 | | - if (flat_refine.has_value() && flat_refine.value()) { |
2146 | | - return true; |
2147 | | - } |
2148 | | - |
2149 | | - auto sq_refine_type = get_sq_quantizer_type(hnsw_cfg.refine_type.value()); |
2150 | | - return has_lossless_quant(sq_refine_type, dataFormat); |
2151 | | - } |
2152 | | - return false; |
2153 | | -} |
2154 | | - |
2155 | | -// pick a refine index |
2156 | | -expected<std::unique_ptr<faiss::Index>> |
2157 | | -pick_refine_index(const DataFormatEnum data_format, const std::optional<std::string>& refine_type, |
2158 | | - std::unique_ptr<faiss::IndexHNSW>&& hnsw_index) { |
2159 | | - // yes |
2160 | | - |
2161 | | - // grab a type of a refine index |
2162 | | - expected<bool> is_fp32_flat = is_flat_refine(refine_type); |
2163 | | - if (!is_fp32_flat.has_value()) { |
2164 | | - return expected<std::unique_ptr<faiss::Index>>::Err(Status::invalid_args, ""); |
2165 | | - } |
2166 | | - |
2167 | | - const bool is_fp32_flat_v = is_fp32_flat.value(); |
2168 | | - |
2169 | | - // check input data_format |
2170 | | - if (data_format == DataFormatEnum::fp16) { |
2171 | | - // make sure that the refine is a valid scalar quantizer other than bf16 (fp32 flat is rejected as well) |
2172 | | - auto refine_sq_type = get_sq_quantizer_type(refine_type.value()); |
2173 | | - if (!(refine_sq_type.has_value() && |
2174 | | - (refine_sq_type.value() != faiss::ScalarQuantizer::QT_bf16 && !is_fp32_flat_v))) { |
2175 | | - LOG_KNOWHERE_ERROR_ << "fp16 input data does not accept bf16 or fp32 as a refine index."; |
2176 | | - return expected<std::unique_ptr<faiss::Index>>::Err( |
2177 | | - Status::invalid_args, "fp16 input data does not accept bf16 or fp32 as a refine index."); |
2178 | | - } |
2179 | | - } |
2180 | | - |
2181 | | - if (data_format == DataFormatEnum::bf16) { |
2182 | | - // make sure that the refine is a valid scalar quantizer other than fp16 (fp32 flat is rejected as well) |
2183 | | - auto refine_sq_type = get_sq_quantizer_type(refine_type.value()); |
2184 | | - if (!(refine_sq_type.has_value() && |
2185 | | - (refine_sq_type.value() != faiss::ScalarQuantizer::QT_fp16 && !is_fp32_flat_v))) { |
2186 | | - LOG_KNOWHERE_ERROR_ << "bf16 input data does not accept fp16 or fp32 as a refine index."; |
2187 | | - return expected<std::unique_ptr<faiss::Index>>::Err( |
2188 | | - Status::invalid_args, "bf16 input data does not accept fp16 or fp32 as a refine index."); |
2189 | | - } |
2190 | | - } |
2191 | | - |
2192 | | - // build |
2193 | | - std::unique_ptr<faiss::IndexHNSW> local_hnsw_index = std::move(hnsw_index); |
2194 | | - |
2195 | | - // either build flat or sq |
2196 | | - if (is_fp32_flat_v) { |
2197 | | - // build IndexFlat as a refine |
2198 | | - auto refine_index = std::make_unique<faiss::IndexRefineFlat>(local_hnsw_index.get()); |
2199 | | - |
2200 | | - // let refine_index own everything |
2201 | | - refine_index->own_fields = true; |
2202 | | - local_hnsw_index.release(); |
2203 | | - |
2204 | | - // reassign |
2205 | | - return refine_index; |
2206 | | - } else { |
2207 | | - // build IndexScalarQuantizer as a refine |
2208 | | - auto refine_sq_type = get_sq_quantizer_type(refine_type.value()); |
2209 | | - |
2210 | | - // a redundant check |
2211 | | - if (!refine_sq_type.has_value()) { |
2212 | | - LOG_KNOWHERE_ERROR_ << "Invalid refine type: " << refine_type.value(); |
2213 | | - return expected<std::unique_ptr<faiss::Index>>::Err( |
2214 | | - Status::invalid_args, fmt::format("invalid refine type ({})", refine_type.value())); |
2215 | | - } |
2216 | | - |
2217 | | - // create an sq |
2218 | | - auto sq_refine = std::make_unique<faiss::IndexScalarQuantizer>( |
2219 | | - local_hnsw_index->storage->d, refine_sq_type.value(), local_hnsw_index->storage->metric_type); |
2220 | | - |
2221 | | - auto refine_index = std::make_unique<faiss::IndexRefine>(local_hnsw_index.get(), sq_refine.get()); |
2222 | | - |
2223 | | - // let refine_index own everything |
2224 | | - refine_index->own_refine_index = true; |
2225 | | - refine_index->own_fields = true; |
2226 | | - local_hnsw_index.release(); |
2227 | | - sq_refine.release(); |
2228 | | - |
2229 | | - // reassign |
2230 | | - return refine_index; |
2231 | | - } |
2232 | | -} |
2233 | | - |
2234 | | -} // namespace |
2235 | | - |
2236 | 2038 | // |
2237 | 2039 | class BaseFaissRegularIndexHNSWSQNode : public BaseFaissRegularIndexHNSWNode { |
2238 | 2040 | public: |
@@ -2300,7 +2102,10 @@ class BaseFaissRegularIndexHNSWSQNode : public BaseFaissRegularIndexHNSWNode { |
2300 | 2102 |
|
2301 | 2103 | if (hnsw_cfg.refine.value_or(false) && hnsw_cfg.refine_type.has_value()) { |
2302 | 2104 | // yes |
2303 | | - auto final_index_cnd = pick_refine_index(data_format, hnsw_cfg.refine_type, std::move(hnsw_index)); |
| 2105 | + const auto hnsw_d = hnsw_index->storage->d; |
| 2106 | + const auto hnsw_metric_type = hnsw_index->storage->metric_type; |
| 2107 | + auto final_index_cnd = pick_refine_index(data_format, hnsw_cfg.refine_type, std::move(hnsw_index), |
| 2108 | + hnsw_d, hnsw_metric_type); |
2304 | 2109 | if (!final_index_cnd.has_value()) { |
2305 | 2110 | return Status::invalid_args; |
2306 | 2111 | } |
@@ -2368,7 +2173,7 @@ class BaseFaissRegularIndexHNSWSQNodeTemplate : public BaseFaissRegularIndexHNSW |
2368 | 2173 | return true; |
2369 | 2174 | } |
2370 | 2175 |
|
2371 | | - return has_lossless_refine_index(hnsw_sq_cfg, datatype_v<DataType>); |
| 2176 | + return has_lossless_refine_index(hnsw_sq_cfg.refine, hnsw_sq_cfg.refine_type, datatype_v<DataType>); |
2372 | 2177 | } |
2373 | 2178 | }; |
2374 | 2179 |
|
@@ -2449,7 +2254,10 @@ class BaseFaissRegularIndexHNSWPQNode : public BaseFaissRegularIndexHNSWNode { |
2449 | 2254 | std::unique_ptr<faiss::Index> final_index; |
2450 | 2255 | if (hnsw_cfg.refine.value_or(false) && hnsw_cfg.refine_type.has_value()) { |
2451 | 2256 | // yes |
2452 | | - auto final_index_cnd = pick_refine_index(data_format, hnsw_cfg.refine_type, std::move(hnsw_index)); |
| 2257 | + const auto hnsw_d = hnsw_index->storage->d; |
| 2258 | + const auto hnsw_metric_type = hnsw_index->storage->metric_type; |
| 2259 | + auto final_index_cnd = pick_refine_index(data_format, hnsw_cfg.refine_type, std::move(hnsw_index), |
| 2260 | + hnsw_d, hnsw_metric_type); |
2453 | 2261 | if (!final_index_cnd.has_value()) { |
2454 | 2262 | return Status::invalid_args; |
2455 | 2263 | } |
@@ -2640,7 +2448,7 @@ class BaseFaissRegularIndexHNSWPQNodeTemplate : public BaseFaissRegularIndexHNSW |
2640 | 2448 | static bool |
2641 | 2449 | StaticHasRawData(const knowhere::BaseConfig& config, const IndexVersion& version) { |
2642 | 2450 | auto hnsw_cfg = static_cast<const FaissHnswConfig&>(config); |
2643 | | - return has_lossless_refine_index(hnsw_cfg, datatype_v<DataType>); |
| 2451 | + return has_lossless_refine_index(hnsw_cfg.refine, hnsw_cfg.refine_type, datatype_v<DataType>); |
2644 | 2452 | } |
2645 | 2453 | }; |
2646 | 2454 |
|
@@ -2728,7 +2536,10 @@ class BaseFaissRegularIndexHNSWPRQNode : public BaseFaissRegularIndexHNSWNode { |
2728 | 2536 | std::unique_ptr<faiss::Index> final_index; |
2729 | 2537 | if (hnsw_cfg.refine.value_or(false) && hnsw_cfg.refine_type.has_value()) { |
2730 | 2538 | // yes |
2731 | | - auto final_index_cnd = pick_refine_index(data_format, hnsw_cfg.refine_type, std::move(hnsw_index)); |
| 2539 | + const auto hnsw_d = hnsw_index->storage->d; |
| 2540 | + const auto hnsw_metric_type = hnsw_index->storage->metric_type; |
| 2541 | + auto final_index_cnd = pick_refine_index(data_format, hnsw_cfg.refine_type, std::move(hnsw_index), |
| 2542 | + hnsw_d, hnsw_metric_type); |
2732 | 2543 | if (!final_index_cnd.has_value()) { |
2733 | 2544 | return Status::invalid_args; |
2734 | 2545 | } |
@@ -2920,7 +2731,7 @@ class BaseFaissRegularIndexHNSWPRQNodeTemplate : public BaseFaissRegularIndexHNS |
2920 | 2731 | static bool |
2921 | 2732 | StaticHasRawData(const knowhere::BaseConfig& config, const IndexVersion& version) { |
2922 | 2733 | auto hnsw_cfg = static_cast<const FaissHnswConfig&>(config); |
2923 | | - return has_lossless_refine_index(hnsw_cfg, datatype_v<DataType>); |
| 2734 | + return has_lossless_refine_index(hnsw_cfg.refine, hnsw_cfg.refine_type, datatype_v<DataType>); |
2924 | 2735 | } |
2925 | 2736 | }; |
2926 | 2737 |
|
|
0 commit comments