99// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
1010// or implied. See the License for the specific language governing permissions and limitations under the License.
1111
12+ #include < faiss/cppcontrib/knowhere/IndexBinaryScalarQuantizer.h>
1213#include < faiss/cppcontrib/knowhere/IndexCosine.h>
1314#include < faiss/cppcontrib/knowhere/IndexFlat.h>
15+ #include < faiss/cppcontrib/knowhere/IndexHNSWBinary.h>
1416#include < faiss/cppcontrib/knowhere/IndexSQ4Uniform.h>
1517#include < faiss/cppcontrib/knowhere/MetricType.h>
1618#include < faiss/cppcontrib/knowhere/impl/CountSizeIOWriter.h>
3234#include " common/metric.h"
3335#include " faiss/cppcontrib/knowhere/IndexHNSW.h"
3436#include " faiss/cppcontrib/knowhere/IndexRefine.h"
35- #include " faiss/cppcontrib/knowhere/impl/ScalarQuantizer.h"
3637#include " faiss/cppcontrib/knowhere/index_io.h"
3738#include " faiss/impl/mapped_io.h"
3839#include " index/clustering_config.h"
@@ -546,10 +547,10 @@ convert_rows_to_fp32(const void* const __restrict src_in, float* const __restric
546547 // where each query_row has ((dim + 7) / 8) * 8 bits, and the total is nrows * ((dim + 7) / 8) * 8 bits.
547548 // But the final format required is nrows * dim * 32 bits (float).
548549 // There are actually two conversions happening here:
549- // 1. Each uint8_t value must be converted to float (in `BinarySQDistanceComputerWrapper ::set_query`
550- // and `ScalarQuantizer::compute_codes`) , it will be converted back to uint8_t). [same as int8]
550+ // 1. Each uint8_t value must be converted to float (in `BinaryFlatCodesDC::set_query`, inside
551+ // `IndexBinaryScalarQuantizer`, it will be converted back to uint8_t). [same as int8]
551552 // 2. Each row must occupy dim * 32 bits of space, even if not all bits are filled;
552- // this is required by the convention set in `ScalarQuantizer::compute_codes` .
553+ // this is required by the convention set by `IndexBinaryScalarQuantizer::sa_encode`.
553554 const knowhere::bin1* const src = reinterpret_cast <const knowhere::bin1*>(src_in);
554555 auto uint8_dim = (dim + 7 ) / 8 ;
555556 for (size_t i = 0 ; i < nrows; i++) {
@@ -711,20 +712,26 @@ get_index_data_format(const faiss::Index* index) {
711712 return DataFormatEnum::fp32;
712713 }
713714
714- // is it sq?
715- // note: IndexScalarQuantizerCosine preserves the original data, no cosine norm is appliesd
716- auto index_sq = dynamic_cast <const faiss::cppcontrib::knowhere::IndexScalarQuantizer*>(index);
717- if (index_sq != nullptr ) {
718- if (index_sq->sq .qtype == faiss::cppcontrib::knowhere::ScalarQuantizer::QT_bf16) {
719- return DataFormatEnum::bf16 ;
720- } else if (index_sq->sq .qtype == faiss::cppcontrib::knowhere::ScalarQuantizer::QT_fp16) {
721- return DataFormatEnum::fp16;
722- } else if (index_sq->sq .qtype == faiss::cppcontrib::knowhere::ScalarQuantizer::QT_8bit_direct_signed) {
723- return DataFormatEnum::int8;
724- } else if (index_sq->sq .qtype == faiss::cppcontrib::knowhere::ScalarQuantizer::QT_1bit_direct) {
725- return DataFormatEnum::bin1;
726- } else {
727- return std::nullopt ;
715+ // is it binary (1-bit-direct)? Routed through
716+ // IndexBinaryScalarQuantizer, which replaces the legacy
717+ // IndexScalarQuantizer(QT_1bit_direct) path.
718+ if (dynamic_cast <const faiss::cppcontrib::knowhere::IndexBinaryScalarQuantizer*>(index) != nullptr ) {
719+ return DataFormatEnum::bin1;
720+ }
721+
722+ // is it sq? All SQ storage produced by knowhere now inherits from
723+ // baseline faiss::IndexScalarQuantizer (Cosine/SQ4U wrappers,
724+ // plain IndexHNSWSQ, and refine).
725+ if (auto * index_sq = dynamic_cast <const faiss::IndexScalarQuantizer*>(index)) {
726+ switch (index_sq->sq .qtype ) {
727+ case faiss::ScalarQuantizer::QT_bf16:
728+ return DataFormatEnum::bf16 ;
729+ case faiss::ScalarQuantizer::QT_fp16:
730+ return DataFormatEnum::fp16;
731+ case faiss::ScalarQuantizer::QT_8bit_direct_signed:
732+ return DataFormatEnum::int8;
733+ default :
734+ return std::nullopt ;
728735 }
729736 }
730737
@@ -2068,9 +2075,8 @@ class BaseFaissRegularIndexHNSWFlatNode : public BaseFaissRegularIndexHNSWNode {
20682075 if (is_binary) {
20692076 if (metric.value () == faiss::MetricType::METRIC_Hamming ||
20702077 metric.value () == faiss::MetricType::METRIC_Jaccard) {
2071- hnsw_index = std::make_unique<faiss::cppcontrib::knowhere::IndexHNSWSQ>(
2072- dim, faiss::cppcontrib::knowhere::ScalarQuantizer::QT_1bit_direct, hnsw_cfg.M .value (),
2073- metric.value ());
2078+ hnsw_index = std::make_unique<faiss::cppcontrib::knowhere::IndexHNSWBinary>(dim, hnsw_cfg.M .value (),
2079+ metric.value ());
20742080 } else {
20752081 LOG_KNOWHERE_ERROR_ << " Unsupported metric for binary data: " << hnsw_cfg.metric_type .value ();
20762082 return Status::invalid_metric_type;
@@ -2082,14 +2088,13 @@ class BaseFaissRegularIndexHNSWFlatNode : public BaseFaissRegularIndexHNSWNode {
20822088 std::make_unique<faiss::cppcontrib::knowhere::IndexHNSWFlatCosine>(dim, hnsw_cfg.M .value ());
20832089 } else if (data_format == DataFormatEnum::fp16) {
20842090 hnsw_index = std::make_unique<faiss::cppcontrib::knowhere::IndexHNSWSQCosine>(
2085- dim, faiss::cppcontrib::knowhere:: ScalarQuantizer::QT_fp16, hnsw_cfg.M .value ());
2091+ dim, faiss::ScalarQuantizer::QT_fp16, hnsw_cfg.M .value ());
20862092 } else if (data_format == DataFormatEnum::bf16 ) {
20872093 hnsw_index = std::make_unique<faiss::cppcontrib::knowhere::IndexHNSWSQCosine>(
2088- dim, faiss::cppcontrib::knowhere:: ScalarQuantizer::QT_bf16, hnsw_cfg.M .value ());
2094+ dim, faiss::ScalarQuantizer::QT_bf16, hnsw_cfg.M .value ());
20892095 } else if (data_format == DataFormatEnum::int8) {
20902096 hnsw_index = std::make_unique<faiss::cppcontrib::knowhere::IndexHNSWSQCosine>(
2091- dim, faiss::cppcontrib::knowhere::ScalarQuantizer::QT_8bit_direct_signed,
2092- hnsw_cfg.M .value ());
2097+ dim, faiss::ScalarQuantizer::QT_8bit_direct_signed, hnsw_cfg.M .value ());
20932098 } else {
20942099 LOG_KNOWHERE_ERROR_ << " Unsupported metric type: " << hnsw_cfg.metric_type .value ();
20952100 return Status::invalid_metric_type;
@@ -2100,16 +2105,13 @@ class BaseFaissRegularIndexHNSWFlatNode : public BaseFaissRegularIndexHNSWNode {
21002105 dim, hnsw_cfg.M .value (), metric.value ());
21012106 } else if (data_format == DataFormatEnum::fp16) {
21022107 hnsw_index = std::make_unique<faiss::cppcontrib::knowhere::IndexHNSWSQ>(
2103- dim, faiss::cppcontrib::knowhere::ScalarQuantizer::QT_fp16, hnsw_cfg.M .value (),
2104- metric.value ());
2108+ dim, faiss::ScalarQuantizer::QT_fp16, hnsw_cfg.M .value (), metric.value ());
21052109 } else if (data_format == DataFormatEnum::bf16 ) {
21062110 hnsw_index = std::make_unique<faiss::cppcontrib::knowhere::IndexHNSWSQ>(
2107- dim, faiss::cppcontrib::knowhere::ScalarQuantizer::QT_bf16, hnsw_cfg.M .value (),
2108- metric.value ());
2111+ dim, faiss::ScalarQuantizer::QT_bf16, hnsw_cfg.M .value (), metric.value ());
21092112 } else if (data_format == DataFormatEnum::int8) {
21102113 hnsw_index = std::make_unique<faiss::cppcontrib::knowhere::IndexHNSWSQ>(
2111- dim, faiss::cppcontrib::knowhere::ScalarQuantizer::QT_8bit_direct_signed,
2112- hnsw_cfg.M .value (), metric.value ());
2114+ dim, faiss::ScalarQuantizer::QT_8bit_direct_signed, hnsw_cfg.M .value (), metric.value ());
21132115 } else {
21142116 LOG_KNOWHERE_ERROR_ << " Unsupported metric type: " << hnsw_cfg.metric_type .value ();
21152117 return Status::invalid_metric_type;
@@ -2548,7 +2550,7 @@ class BaseFaissRegularIndexHNSWSQNode : public BaseFaissRegularIndexHNSWNode {
25482550
25492551 // create an index
25502552 const bool is_cosine = IsMetricType (hnsw_cfg.metric_type .value (), metric::COSINE);
2551- const bool is_sq4u = sq_type.value () == faiss::cppcontrib::knowhere:: ScalarQuantizer::QT_4bit_uniform;
2553+ const bool is_sq4u = sq_type.value () == faiss::ScalarQuantizer::QT_4bit_uniform;
25522554
25532555 // should refine be used?
25542556 std::unique_ptr<faiss::Index> final_index;
@@ -2570,6 +2572,17 @@ class BaseFaissRegularIndexHNSWSQNode : public BaseFaissRegularIndexHNSWNode {
25702572 } else {
25712573 hnsw_index = std::make_unique<faiss::cppcontrib::knowhere::IndexHNSWSQ>(
25722574 dim, sq_type.value (), hnsw_cfg.M .value (), metric.value ());
2575+ // QT_4bit_uniform + L2 benefits from quantile-based range
2576+ // estimation. This used to be hard-coded inside the fork
2577+ // IndexScalarQuantizer ctor; moved here so that ctor is
2578+ // behaviorally equivalent to baseline.
2579+ if (is_sq4u) {
2580+ auto * idx_sq = dynamic_cast <faiss::IndexScalarQuantizer*>(hnsw_index->storage );
2581+ if (idx_sq != nullptr ) {
2582+ idx_sq->sq .rangestat = faiss::ScalarQuantizer::RS_quantiles;
2583+ idx_sq->sq .rangestat_arg = 0.01 ;
2584+ }
2585+ }
25732586 }
25742587
25752588 hnsw_index->hnsw .efConstruction = hnsw_cfg.efConstruction .value ();
0 commit comments