diff --git a/cmake/GetHNSW.cmake b/cmake/GetHNSW.cmake index 2b45e3e3..476d6398 100644 --- a/cmake/GetHNSW.cmake +++ b/cmake/GetHNSW.cmake @@ -13,7 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -set ( HNSW_GITHUB "https://github.com/manticoresoftware/hnswlib/archive/6568d3b.zip" ) +set ( HNSW_GITHUB "https://github.com/manticoresoftware/hnswlib/archive/091f3dd.zip" ) set ( HNSW_BUNDLEZIP "${LIBS_BUNDLE}/hnswlib-0.7.0.tar.gz" ) cmake_minimum_required ( VERSION 3.17 FATAL_ERROR ) diff --git a/knn/knn.cpp b/knn/knn.cpp index 379cc640..d130a0f2 100644 --- a/knn/knn.cpp +++ b/knn/knn.cpp @@ -310,6 +310,48 @@ class L2BinarySIMD16ResidualsDistFn_c : public DistFnDispatch_c<&L2BinaryFloatDi }; #endif +// build-mode DistFn classes +using IPBinaryGenericBuildDistFn_c = DistFnDispatch_c<&IPBinaryFloatDistanceGenericBuild>; +using L2BinaryGenericBuildDistFn_c = DistFnDispatch_c<&L2BinaryFloatDistanceGenericBuild>; + +#if !defined(USE_SIMDE) +class IPBinarySIMD16BuildDistFn_c : public DistFnDispatch_c<&IPBinaryFloatDistanceSIMD16Build> +{ +public: + static void Eval2 ( const void * pVect1, const void * pVect2A, const void * pVect2B, size_t uRowID1, size_t uRowID2A, size_t uRowID2B, const void * pParam, float & fDistA, float & fDistB ) + { + IPBinaryFloatDistanceSIMD16Batch2Build ( pVect1, pVect2A, pVect2B, uRowID1, uRowID2A, uRowID2B, pParam, fDistA, fDistB ); + } +}; + +class IPBinarySIMD16ResidualsBuildDistFn_c : public DistFnDispatch_c<&IPBinaryFloatDistanceSIMD16ResidualsBuild> +{ +public: + static void Eval2 ( const void * pVect1, const void * pVect2A, const void * pVect2B, size_t uRowID1, size_t uRowID2A, size_t uRowID2B, const void * pParam, float & fDistA, float & fDistB ) + { + IPBinaryFloatDistanceSIMD16ResidualsBatch2Build ( pVect1, pVect2A, pVect2B, uRowID1, uRowID2A, uRowID2B, pParam, fDistA, fDistB ); + } +}; + +class L2BinarySIMD16BuildDistFn_c : public DistFnDispatch_c<&L2BinaryFloatDistanceSIMD16Build> +{ +public: + static void Eval2 ( const void * pVect1, const void * pVect2A, const void * pVect2B, size_t uRowID1, size_t uRowID2A, size_t uRowID2B, const void * pParam, float & fDistA, float & fDistB ) + { + L2BinaryFloatDistanceSIMD16Batch2Build ( pVect1, pVect2A, pVect2B, uRowID1, uRowID2A, uRowID2B, pParam, fDistA, fDistB ); + } +}; + +class L2BinarySIMD16ResidualsBuildDistFn_c : public DistFnDispatch_c<&L2BinaryFloatDistanceSIMD16ResidualsBuild> +{ +public: + static void Eval2 ( const void * pVect1, const void * pVect2A, const void * pVect2B, size_t uRowID1, size_t uRowID2A, size_t uRowID2B, const void * pParam, float & fDistA, float & fDistB ) + { + L2BinaryFloatDistanceSIMD16ResidualsBatch2Build ( pVect1, pVect2A, pVect2B, uRowID1, uRowID2A, uRowID2B, pParam, fDistA, fDistB ); + } +}; +#endif + template static void RunSearchPath ( const hnswlib::HierarchicalNSW & tAlg, std::vector & dResults, const void * pData, int64_t iResults, HNSWFilterWrapper_c * pFilter, size_t * pSearchEf, int iSearchPath ) { @@ -359,7 +401,8 @@ void HNSWIndex_c::Search ( std::vector & dResults, const Span_TEncode ( 0, dData, dQuantized ); + std::vector dUnusedQuantizedForQuery; + m_pQuantizer->Encode ( 0, dData, dQuantized, dUnusedQuantizedForQuery ); pData = dQuantized.data(); } @@ -536,7 +579,8 @@ class HNSWIndexBuilder_i virtual ~HNSWIndexBuilder_i() = default; virtual void Train ( const util::Span_T & dData ) = 0; - virtual bool AddDoc ( uint32_t uRowID, const util::Span_T & dData, std::string & sError ) = 0; + virtual bool FinalizeTraining ( std::string & sError ) = 0; + virtual bool AddDoc ( uint32_t uRowID, const util::Span_T & dData, BuildContext_t & tBuildCtx, std::string & sError ) = 0; virtual void Save ( FileWriter_c & tWriter ) = 0; virtual const AttrWithSettings_t & GetAttr() const = 0; virtual const QuantizationSettings_t & GetQuantizationSettings() const = 0; @@ -549,28 +593,56 @@ class HNSWIndexBuilder_c : public HNSWIndexBuilder_i, public HNSWDist_c HNSWIndexBuilder_c ( const AttrWithSettings_t & tAttr, int64_t iNumElements, ScalarQuantizer_i * pQuantizer ); void Train ( const util::Span_T & dData ) override; - bool AddDoc ( uint32_t uRowID, const util::Span_T & dData, std::string & sError ) override; + bool FinalizeTraining ( std::string & sError ) override; + bool AddDoc ( uint32_t uRowID, const util::Span_T & dData, BuildContext_t & tBuildCtx, std::string & sError ) override; void Save ( FileWriter_c & tWriter ) override; const AttrWithSettings_t & GetAttr() const override { return m_tAttr; } const QuantizationSettings_t & GetQuantizationSettings() const override { return m_pQuantizer->GetSettings(); } private: - AttrWithSettings_t m_tAttr; - bool m_bFirstDoc = true; - SpanResizeable_T m_dNormalized; - std::vector m_dQuantized; - std::unique_ptr m_pQuantizer; - std::unique_ptr> m_pAlg; + using AddPoint_fn = void (*) ( hnswlib::HierarchicalNSW &, const void *, uint32_t ); + + template + static void AddPointTyped ( hnswlib::HierarchicalNSW & tAlg, const void * pVec, uint32_t uRowID ) { tAlg.template addPoint ( pVec, (size_t)uRowID, -1 ); } + static void AddPointFallback ( hnswlib::HierarchicalNSW & tAlg, const void * pVec, uint32_t uRowID ) { tAlg.addPoint ( pVec, (size_t)uRowID ); } + AddPoint_fn SelectAddPointFn() const; + + AttrWithSettings_t m_tAttr; + std::unique_ptr m_pQuantizer; + std::unique_ptr> m_pAlg; + AddPoint_fn m_fnAddPoint = AddPointFallback; }; +HNSWIndexBuilder_c::AddPoint_fn HNSWIndexBuilder_c::SelectAddPointFn() const +{ + switch ( m_pSpace->GetDistFuncId() ) + { + case DistFuncId_e::IP_FLOAT32: return AddPointTyped; + case DistFuncId_e::L2_FLOAT32: return AddPointTyped; + case DistFuncId_e::IP_BINARY_GENERIC: return AddPointTyped; + case DistFuncId_e::L2_BINARY_GENERIC: return AddPointTyped; + +#if !defined(USE_SIMDE) + case DistFuncId_e::IP_BINARY_SIMD16: return AddPointTyped; + case DistFuncId_e::IP_BINARY_SIMD16_RESIDUALS: return AddPointTyped; + case DistFuncId_e::L2_BINARY_SIMD16: return AddPointTyped; + case DistFuncId_e::L2_BINARY_SIMD16_RESIDUALS: return AddPointTyped; +#endif + + default: + return AddPointFallback; + } +} + + HNSWIndexBuilder_c::HNSWIndexBuilder_c ( const AttrWithSettings_t & tAttr, int64_t iNumElements, ScalarQuantizer_i * pQuantizer ) : HNSWDist_c ( tAttr.m_iDims, tAttr.m_eHNSWSimilarity, tAttr.m_eQuantization, true ) , m_tAttr ( tAttr ) , m_pQuantizer ( pQuantizer ) { m_pAlg = std::make_unique>( m_pSpace.get(), iNumElements, m_tAttr.m_iHNSWM, m_tAttr.m_iHNSWEFConstruction ); - m_dNormalized.resize ( tAttr.m_iDims ); + m_fnAddPoint = SelectAddPointFn(); } @@ -581,39 +653,51 @@ void HNSWIndexBuilder_c::Train ( const util::Span_T & dData ) } -bool HNSWIndexBuilder_c::AddDoc ( uint32_t uRowID, const util::Span_T & dData, std::string & sError ) +bool HNSWIndexBuilder_c::FinalizeTraining ( std::string & sError ) { - if ( dData.size()!=m_tAttr.m_iDims ) + if ( !m_pQuantizer ) + return true; + + if ( m_pQuantizer->IsFinalized() ) + return true; + + if ( !m_pQuantizer->FinalizeTraining ( sError ) ) + return false; + + m_pSpace->SetQuantizationSettings ( *m_pQuantizer ); + return true; +} + + +bool HNSWIndexBuilder_c::AddDoc ( uint32_t uRowID, const util::Span_T & dData, BuildContext_t & tBuildCtx, std::string & sError ) +{ + if ( dData.size()!=(size_t)m_tAttr.m_iDims ) { sError = FormatStr ( "HNSW error: data has %llu values, index '%s' needs %d values", dData.size(), m_tAttr.m_sName.c_str(), m_tAttr.m_iDims ); return false; } + assert ( !m_pQuantizer || m_pQuantizer->IsFinalized() ); + Span_T dToAdd = dData; if ( m_tAttr.m_eHNSWSimilarity==HNSWSimilarity_e::COSINE ) { - memcpy ( m_dNormalized.data(), dData.data(), dData.size()*sizeof(dData[0] ) ); - VecNormalize(m_dNormalized); - dToAdd = m_dNormalized; + tBuildCtx.m_dNormalized.resize ( dData.size() ); + memcpy ( tBuildCtx.m_dNormalized.data(), dData.data(), dData.size()*sizeof(dData[0] ) ); + VecNormalize ( tBuildCtx.m_dNormalized ); + dToAdd = tBuildCtx.m_dNormalized; } + const void * pVec = nullptr; if ( m_pQuantizer ) { - if ( m_bFirstDoc ) - { - m_bFirstDoc = false; - - if ( !m_pQuantizer->FinalizeTraining(sError) ) - return false; - - m_pSpace->SetQuantizationSettings ( *m_pQuantizer ); - } - - m_pQuantizer->Encode ( uRowID, dToAdd, m_dQuantized ); - m_pAlg->addPoint ( (void*)m_dQuantized.data(), (size_t)uRowID ); + m_pQuantizer->Encode ( uRowID, dToAdd, tBuildCtx.m_dQuantized, tBuildCtx.m_dQuantizedForQuery ); + pVec = (void*)tBuildCtx.m_dQuantized.data(); } else - m_pAlg->addPoint ( (void*)dToAdd.data(), (size_t)uRowID ); + pVec = (void*)dToAdd.data(); + + m_fnAddPoint ( *m_pAlg, pVec, uRowID ); return true; } @@ -634,14 +718,13 @@ class HNSWBuilder_c : public Builder_i public: HNSWBuilder_c ( const Schema_t & tSchema, int64_t iNumElements, const std::string & sTmpFilename ); - void Train ( int iAttr, uint32_t uRowID, const util::Span_T & dData ) override { m_dIndexes[iAttr]->Train(dData); } - bool SetAttr ( int iAttr, uint32_t uRowID, const util::Span_T & dData ) override { return m_dIndexes[iAttr]->AddDoc ( uRowID, dData, m_sError ); } + void Train ( int iAttr, uint32_t uRowID, const util::Span_T & dData ) override { m_dIndexes[iAttr]->Train(dData); } + bool SetAttr ( int iAttr, uint32_t uRowID, const util::Span_T & dData, BuildContext_t & tBuildCtx ) override { return m_dIndexes[iAttr]->AddDoc ( uRowID, dData, tBuildCtx, tBuildCtx.m_sError ); } + bool FinalizeTraining ( std::string & sError ) override; bool Save ( const std::string & sFilename, size_t tBufferSize, std::string & sError ) override; - const std::string & GetError() const override { return m_sError; } private: std::vector> m_dIndexes; - std::string m_sError; }; @@ -653,6 +736,16 @@ HNSWBuilder_c::HNSWBuilder_c ( const Schema_t & tSchema, int64_t iNumElements, c } +bool HNSWBuilder_c::FinalizeTraining ( std::string & sError ) +{ + for ( auto & i : m_dIndexes ) + if ( !i->FinalizeTraining(sError) ) + return false; + + return true; +} + + bool HNSWBuilder_c::Save ( const std::string & sFilename, size_t tBufferSize, std::string & sError ) { FileWriter_c tWriter; diff --git a/knn/knn.h b/knn/knn.h index 5eb92737..dea64057 100644 --- a/knn/knn.h +++ b/knn/knn.h @@ -26,7 +26,7 @@ namespace knn { -static const int LIB_VERSION = 13; +static const int LIB_VERSION = 14; static const uint32_t STORAGE_VERSION = 3; enum class HNSWSimilarity_e @@ -122,15 +122,24 @@ class KNN_i virtual bool ShouldUseFullscan ( const std::string & sName, int64_t iResults, int iEf, int64_t iFilterCount ) = 0; }; +// passed via SetAttr so the builder itself holds no per-row mutable state +struct BuildContext_t +{ + util::SpanResizeable_T m_dNormalized; + std::vector m_dQuantized; + std::vector m_dQuantizedForQuery; // 4-bit transposed representation, produced only by the BIT1 binary quantizer during BUILD mode + std::string m_sError; +}; + class Builder_i { public: virtual ~Builder_i() = default; virtual void Train ( int iAttr, uint32_t uRowID, const util::Span_T & dData ) = 0; - virtual bool SetAttr ( int iAttr, uint32_t uRowID, const util::Span_T & dData ) = 0; + virtual bool SetAttr ( int iAttr, uint32_t uRowID, const util::Span_T & dData, BuildContext_t & tBuildCtx ) = 0; + virtual bool FinalizeTraining ( std::string & sError ) = 0; virtual bool Save ( const std::string & sFilename, size_t tBufferSize, std::string & sError ) = 0; - virtual const std::string & GetError() const = 0; }; class TextToEmbeddings_i diff --git a/knn/quantizer.cpp b/knn/quantizer.cpp index 071d7342..e7c5962a 100644 --- a/knn/quantizer.cpp +++ b/knn/quantizer.cpp @@ -50,7 +50,8 @@ class ScalarQuantizer8Bit_c : public ScalarQuantizer_i void Train ( const Span_T & dPoint ) override; bool FinalizeTraining ( std::string & sError ) override; - void Encode ( uint32_t uRowID, const Span_T & dPoint, std::vector & dQuantized ) override; + bool IsFinalized() const override { return m_bFinalized; } + void Encode ( uint32_t uRowID, const Span_T & dPoint, std::vector & dQuantized, std::vector & dQuantizedForQuery ) override; void FinalizeEncoding() override {} const QuantizationSettings_t & GetSettings() override; std::function GetPoolFetcher() const override { return nullptr; } @@ -138,7 +139,7 @@ bool ScalarQuantizer8Bit_c::FinalizeTraining ( std::string & sError ) } -void ScalarQuantizer8Bit_c::Encode ( uint32_t uRowID, const Span_T & dPoint, std::vector & dQuantized ) +void ScalarQuantizer8Bit_c::Encode ( uint32_t uRowID, const Span_T & dPoint, std::vector & dQuantized, std::vector & /*dQuantizedForQuery*/ ) { assert(m_bFinalized); @@ -183,11 +184,11 @@ class ScalarQuantizer1Bit_c : public ScalarQuantizer8Bit_c using ScalarQuantizer8Bit_c::ScalarQuantizer8Bit_c; public: - void Encode ( uint32_t uRowID, const Span_T & dPoint, std::vector & dQuantized ) override; + void Encode ( uint32_t uRowID, const Span_T & dPoint, std::vector & dQuantized, std::vector & dQuantizedForQuery ) override; }; -void ScalarQuantizer1Bit_c::Encode ( uint32_t uRowID, const Span_T & dPoint, std::vector & dQuantized ) +void ScalarQuantizer1Bit_c::Encode ( uint32_t uRowID, const Span_T & dPoint, std::vector & dQuantized, std::vector & /*dQuantizedForQuery*/ ) { assert(m_bFinalized); @@ -647,7 +648,8 @@ class ScalarQuantizerBinary_T : public ScalarQuantizer_i void Train ( const Span_T & dPoint ) override; bool FinalizeTraining ( std::string & sError ) override; - void Encode ( uint32_t uRowID, const Span_T & dPoint, std::vector & dQuantized ) override; + bool IsFinalized() const override { return m_bFinalized; } + void Encode ( uint32_t uRowID, const Span_T & dPoint, std::vector & dQuantized, std::vector & dQuantizedForQuery ) override; void FinalizeEncoding() override; const QuantizationSettings_t & GetSettings() override; @@ -659,7 +661,6 @@ class ScalarQuantizerBinary_T : public ScalarQuantizer_i HNSWSimilarity_e m_eSimilarity = HNSWSimilarity_e::COSINE; std::string m_sTmpFilename; std::vector m_dCentroid64; - std::vector m_dQuantizedForQuery; MappedBuffer_T m_tBuffer4Bit; size_t m_uDim = 0; bool m_bFinalized = false; @@ -708,16 +709,17 @@ void ScalarQuantizerBinary_T::Train ( const Span_T & dPoint ) } template -void ScalarQuantizerBinary_T::Encode ( uint32_t uRowID, const Span_T & dPoint, std::vector & dQuantized ) +void ScalarQuantizerBinary_T::Encode ( uint32_t uRowID, const Span_T & dPoint, std::vector & dQuantized, std::vector & dQuantizedForQuery ) { assert(m_bFinalized); - m_pQuantizer->Quantize4Bit ( dPoint, m_tSettings.m_dCentroid, BUILD ? m_dQuantizedForQuery : dQuantized ); + m_pQuantizer->Quantize4Bit ( dPoint, m_tSettings.m_dCentroid, BUILD ? dQuantizedForQuery : dQuantized ); if constexpr ( !BUILD ) return; - int64_t iOffset = (int64_t)uRowID * m_dQuantizedForQuery.size(); - memcpy ( m_tBuffer4Bit.data() + iOffset, m_dQuantizedForQuery.data(), m_dQuantizedForQuery.size() ); + assert ( dQuantizedForQuery.size() == m_uQuantized4BitEntrySize ); + int64_t iOffset = (int64_t)uRowID * dQuantizedForQuery.size(); + memcpy ( m_tBuffer4Bit.data() + iOffset, dQuantizedForQuery.data(), dQuantizedForQuery.size() ); m_pQuantizer->Quantize1Bit ( dPoint, m_tSettings.m_dCentroid, dQuantized ); } @@ -766,20 +768,30 @@ bool ScalarQuantizerBinary_T::FinalizeTraining ( std::string & sError ) if ( m_bFinalized ) return true; - m_bFinalized = true; - if ( !m_uTrainedVecs ) + { + m_bFinalized = true; return true; + } - for ( auto & i : m_dCentroid64 ) - m_tSettings.m_dCentroid.push_back ( i/m_uTrainedVecs ); + if ( m_tSettings.m_dCentroid.empty() ) + { + m_tSettings.m_dCentroid.reserve ( m_dCentroid64.size() ); + for ( auto & i : m_dCentroid64 ) + m_tSettings.m_dCentroid.push_back ( i/m_uTrainedVecs ); + } - m_pQuantizer = std::make_unique ( m_uDim, m_eSimilarity ); + if ( !m_pQuantizer ) + m_pQuantizer = std::make_unique ( m_uDim, m_eSimilarity ); - // quantize a fake vector to get quantized size - std::vector dTmp ( m_uDim, 0.0f ); - m_pQuantizer->Quantize4Bit ( dTmp, m_tSettings.m_dCentroid, m_dQuantizedForQuery ); - m_uQuantized4BitEntrySize = m_dQuantizedForQuery.size(); + if ( !m_uQuantized4BitEntrySize ) + { + // quantize a fake vector to get quantized size + std::vector dTmp ( m_uDim, 0.0f ); + std::vector dSizeProbe; + m_pQuantizer->Quantize4Bit ( dTmp, m_tSettings.m_dCentroid, dSizeProbe ); + m_uQuantized4BitEntrySize = dSizeProbe.size(); + } FILE * pFile = fopen ( m_sTmpFilename.c_str(), "wb" ); if ( !pFile ) @@ -793,7 +805,11 @@ bool ScalarQuantizerBinary_T::FinalizeTraining ( std::string & sError ) fwrite ( "", 1, 1, pFile ); fclose ( pFile ); - return m_tBuffer4Bit.Open ( m_sTmpFilename.c_str(), true, sError ); + if ( !m_tBuffer4Bit.Open ( m_sTmpFilename.c_str(), true, sError ) ) + return false; + + m_bFinalized = true; + return true; } /////////////////////////////////////////////////////////////////////////////// diff --git a/knn/quantizer.h b/knn/quantizer.h index 72c94646..459678d0 100644 --- a/knn/quantizer.h +++ b/knn/quantizer.h @@ -67,7 +67,8 @@ class ScalarQuantizer_i virtual void Train ( const util::Span_T & dPoint ) = 0; virtual bool FinalizeTraining ( std::string & sError ) = 0; - virtual void Encode ( uint32_t uRowID, const util::Span_T & dPoint, std::vector & dQuantized ) = 0; + virtual bool IsFinalized () const = 0; + virtual void Encode ( uint32_t uRowID, const util::Span_T & dPoint, std::vector & dQuantized, std::vector & dQuantizedForQuery ) = 0; virtual void FinalizeEncoding() = 0; virtual const QuantizationSettings_t & GetSettings() = 0; diff --git a/knn/space.cpp b/knn/space.cpp index cb783a72..5b31fc64 100644 --- a/knn/space.cpp +++ b/knn/space.cpp @@ -1138,7 +1138,7 @@ static FORCE_INLINE float L2BinaryFloatDistanceFromHammingDist ( const Binary4Bi template -static float IPBinaryFloatDistance ( const void * __restrict pVect1, const void * __restrict pVect2, size_t uRowID1, size_t uRowID2, const void * __restrict pParam ) +FORCE_INLINE static float IPBinaryFloatDistance ( const void * __restrict pVect1, const void * __restrict pVect2, size_t uRowID1, size_t uRowID2, const void * __restrict pParam ) { const auto & tBinaryParam = *(const DistFuncParamBinary_t*)pParam; @@ -1173,7 +1173,7 @@ static float IPBinaryFloatDistance ( const void * __restrict pVect1, const void // in org.elasticsearch.index.codec.vectors.es816.ES816BinaryFlatVectorsScorer // Permalink: https://github.com/elastic/elasticsearch/blob/1dd41ec2b683a7b7c9c16af404b842cf85cbd5bc/server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryFlatVectorsScorer.java template -static float L2BinaryFloatDistance ( const void * __restrict pVect1, const void * __restrict pVect2, size_t uRowID1, size_t uRowID2, const void * __restrict pParam ) +FORCE_INLINE static float L2BinaryFloatDistance ( const void * __restrict pVect1, const void * __restrict pVect2, size_t uRowID1, size_t uRowID2, const void * __restrict pParam ) { const auto & tBinaryParam = *(const DistFuncParamBinary_t*)pParam; @@ -1211,11 +1211,21 @@ float IPBinaryFloatDistanceGeneric ( const void * pVect1, const void * pVect2, s return IPBinaryFloatDistance ( pVect1, pVect2, uRowID1, uRowID2, pParam ); } +float IPBinaryFloatDistanceGenericBuild ( const void * pVect1, const void * pVect2, size_t uRowID1, size_t uRowID2, const void * pParam ) +{ + return IPBinaryFloatDistance ( pVect1, pVect2, uRowID1, uRowID2, pParam ); +} + float L2BinaryFloatDistanceGeneric ( const void * pVect1, const void * pVect2, size_t uRowID1, size_t uRowID2, const void * pParam ) { return L2BinaryFloatDistance ( pVect1, pVect2, uRowID1, uRowID2, pParam ); } +float L2BinaryFloatDistanceGenericBuild ( const void * pVect1, const void * pVect2, size_t uRowID1, size_t uRowID2, const void * pParam ) +{ + return L2BinaryFloatDistance ( pVect1, pVect2, uRowID1, uRowID2, pParam ); +} + #if !defined(USE_SIMDE) float IPBinaryFloatDistanceSIMD16 ( const void * pVect1, const void * pVect2, size_t uRowID1, size_t uRowID2, const void * pParam ) @@ -1223,13 +1233,23 @@ float IPBinaryFloatDistanceSIMD16 ( const void * pVect1, const void * pVect2, si return IPBinaryFloatDistance> ( pVect1, pVect2, uRowID1, uRowID2, pParam ); } +float IPBinaryFloatDistanceSIMD16Build ( const void * pVect1, const void * pVect2, size_t uRowID1, size_t uRowID2, const void * pParam ) +{ + return IPBinaryFloatDistance> ( pVect1, pVect2, uRowID1, uRowID2, pParam ); +} + float IPBinaryFloatDistanceSIMD16Residuals ( const void * pVect1, const void * pVect2, size_t uRowID1, size_t uRowID2, const void * pParam ) { return IPBinaryFloatDistance> ( pVect1, pVect2, uRowID1, uRowID2, pParam ); } -template -static void IPBinaryFloatDistanceBatch2 ( const void * pVect1, const void * pVect2A, const void * pVect2B, size_t, size_t uRowID2A, size_t uRowID2B, const void * pParam, float & fDistA, float & fDistB ) +float IPBinaryFloatDistanceSIMD16ResidualsBuild ( const void * pVect1, const void * pVect2, size_t uRowID1, size_t uRowID2, const void * pParam ) +{ + return IPBinaryFloatDistance> ( pVect1, pVect2, uRowID1, uRowID2, pParam ); +} + +template +FORCE_INLINE static void IPBinaryFloatDistanceBatch2 ( const void * pVect1, const void * pVect2A, const void * pVect2B, size_t uRowID1, size_t uRowID2A, size_t uRowID2B, const void * pParam, float & fDistA, float & fDistB ) { const auto & tBinaryParam = *(const DistFuncParamBinary_t*)pParam; @@ -1237,6 +1257,14 @@ static void IPBinaryFloatDistanceBatch2 ( const void * pVect1, const void * pVec auto pVA = (const uint8_t *)pVect2A; auto pVB = (const uint8_t *)pVect2B; + // Build mode: source is 1-bit raw data; fetch its 4-bit representation from the pool + // using the source row id. Amortized over both candidates by doing this once. + if constexpr ( BUILD ) + { + if ( uRowID1!=(size_t)-1 ) + pV1 = tBinaryParam.m_fnFetcher(uRowID1); + } + assert ( uRowID2A!=(size_t)-1 ); assert ( uRowID2B!=(size_t)-1 ); @@ -1258,12 +1286,22 @@ static void IPBinaryFloatDistanceBatch2 ( const void * pVect1, const void * pVec void IPBinaryFloatDistanceSIMD16Batch2 ( const void * pVect1, const void * pVect2A, const void * pVect2B, size_t uRowID1, size_t uRowID2A, size_t uRowID2B, const void * pParam, float & fDistA, float & fDistB ) { - IPBinaryFloatDistanceBatch2 ( pVect1, pVect2A, pVect2B, uRowID1, uRowID2A, uRowID2B, pParam, fDistA, fDistB ); + IPBinaryFloatDistanceBatch2 ( pVect1, pVect2A, pVect2B, uRowID1, uRowID2A, uRowID2B, pParam, fDistA, fDistB ); +} + +void IPBinaryFloatDistanceSIMD16Batch2Build ( const void * pVect1, const void * pVect2A, const void * pVect2B, size_t uRowID1, size_t uRowID2A, size_t uRowID2B, const void * pParam, float & fDistA, float & fDistB ) +{ + IPBinaryFloatDistanceBatch2 ( pVect1, pVect2A, pVect2B, uRowID1, uRowID2A, uRowID2B, pParam, fDistA, fDistB ); } void IPBinaryFloatDistanceSIMD16ResidualsBatch2 ( const void * pVect1, const void * pVect2A, const void * pVect2B, size_t uRowID1, size_t uRowID2A, size_t uRowID2B, const void * pParam, float & fDistA, float & fDistB ) { - IPBinaryFloatDistanceBatch2 ( pVect1, pVect2A, pVect2B, uRowID1, uRowID2A, uRowID2B, pParam, fDistA, fDistB ); + IPBinaryFloatDistanceBatch2 ( pVect1, pVect2A, pVect2B, uRowID1, uRowID2A, uRowID2B, pParam, fDistA, fDistB ); +} + +void IPBinaryFloatDistanceSIMD16ResidualsBatch2Build ( const void * pVect1, const void * pVect2A, const void * pVect2B, size_t uRowID1, size_t uRowID2A, size_t uRowID2B, const void * pParam, float & fDistA, float & fDistB ) +{ + IPBinaryFloatDistanceBatch2 ( pVect1, pVect2A, pVect2B, uRowID1, uRowID2A, uRowID2B, pParam, fDistA, fDistB ); } float L2BinaryFloatDistanceSIMD16 ( const void * pVect1, const void * pVect2, size_t uRowID1, size_t uRowID2, const void * pParam ) @@ -1271,13 +1309,23 @@ float L2BinaryFloatDistanceSIMD16 ( const void * pVect1, const void * pVect2, si return L2BinaryFloatDistance> ( pVect1, pVect2, uRowID1, uRowID2, pParam ); } +float L2BinaryFloatDistanceSIMD16Build ( const void * pVect1, const void * pVect2, size_t uRowID1, size_t uRowID2, const void * pParam ) +{ + return L2BinaryFloatDistance> ( pVect1, pVect2, uRowID1, uRowID2, pParam ); +} + float L2BinaryFloatDistanceSIMD16Residuals ( const void * pVect1, const void * pVect2, size_t uRowID1, size_t uRowID2, const void * pParam ) { return L2BinaryFloatDistance> ( pVect1, pVect2, uRowID1, uRowID2, pParam ); } -template -static void L2BinaryFloatDistanceBatch2 ( const void * pVect1, const void * pVect2A, const void * pVect2B, size_t, size_t uRowID2A, size_t uRowID2B, const void * pParam, float & fDistA, float & fDistB ) +float L2BinaryFloatDistanceSIMD16ResidualsBuild ( const void * pVect1, const void * pVect2, size_t uRowID1, size_t uRowID2, const void * pParam ) +{ + return L2BinaryFloatDistance> ( pVect1, pVect2, uRowID1, uRowID2, pParam ); +} + +template +static FORCE_INLINE void L2BinaryFloatDistanceBatch2 ( const void * pVect1, const void * pVect2A, const void * pVect2B, size_t uRowID1, size_t uRowID2A, size_t uRowID2B, const void * pParam, float & fDistA, float & fDistB ) { const auto & tBinaryParam = *(const DistFuncParamBinary_t*)pParam; @@ -1285,6 +1333,12 @@ static void L2BinaryFloatDistanceBatch2 ( const void * pVect1, const void * pVec auto pVA = (const uint8_t *)pVect2A; auto pVB = (const uint8_t *)pVect2B; + if constexpr ( BUILD ) + { + if ( uRowID1!=(size_t)-1 ) + pV1 = tBinaryParam.m_fnFetcher(uRowID1); + } + assert ( uRowID2A!=(size_t)-1 ); assert ( uRowID2B!=(size_t)-1 ); @@ -1306,12 +1360,22 @@ static void L2BinaryFloatDistanceBatch2 ( const void * pVect1, const void * pVec void L2BinaryFloatDistanceSIMD16Batch2 ( const void * pVect1, const void * pVect2A, const void * pVect2B, size_t uRowID1, size_t uRowID2A, size_t uRowID2B, const void * pParam, float & fDistA, float & fDistB ) { - L2BinaryFloatDistanceBatch2 ( pVect1, pVect2A, pVect2B, uRowID1, uRowID2A, uRowID2B, pParam, fDistA, fDistB ); + L2BinaryFloatDistanceBatch2 ( pVect1, pVect2A, pVect2B, uRowID1, uRowID2A, uRowID2B, pParam, fDistA, fDistB ); +} + +void L2BinaryFloatDistanceSIMD16Batch2Build ( const void * pVect1, const void * pVect2A, const void * pVect2B, size_t uRowID1, size_t uRowID2A, size_t uRowID2B, const void * pParam, float & fDistA, float & fDistB ) +{ + L2BinaryFloatDistanceBatch2 ( pVect1, pVect2A, pVect2B, uRowID1, uRowID2A, uRowID2B, pParam, fDistA, fDistB ); } void L2BinaryFloatDistanceSIMD16ResidualsBatch2 ( const void * pVect1, const void * pVect2A, const void * pVect2B, size_t uRowID1, size_t uRowID2A, size_t uRowID2B, const void * pParam, float & fDistA, float & fDistB ) { - L2BinaryFloatDistanceBatch2 ( pVect1, pVect2A, pVect2B, uRowID1, uRowID2A, uRowID2B, pParam, fDistA, fDistB ); + L2BinaryFloatDistanceBatch2 ( pVect1, pVect2A, pVect2B, uRowID1, uRowID2A, uRowID2B, pParam, fDistA, fDistB ); +} + +void L2BinaryFloatDistanceSIMD16ResidualsBatch2Build ( const void * pVect1, const void * pVect2A, const void * pVect2B, size_t uRowID1, size_t uRowID2A, size_t uRowID2B, const void * pParam, float & fDistA, float & fDistB ) +{ + L2BinaryFloatDistanceBatch2 ( pVect1, pVect2A, pVect2B, uRowID1, uRowID2A, uRowID2B, pParam, fDistA, fDistB ); } #endif // !USE_SIMDE @@ -1341,13 +1405,10 @@ IPSpaceBinaryFloat_c::IPSpaceBinaryFloat_c ( size_t uDim, bool bBuild ) , m_tDistFuncParam ( uDim ) { #if defined(USE_SIMDE) - if ( bBuild ) - m_fnDist = IPBinaryFloatDistance; - else - { - m_fnDist = IPBinaryFloatDistance; - m_eDistFuncId = DistFuncId_e::IP_BINARY_GENERIC; - } + m_fnDist = bBuild + ? IPBinaryFloatDistance + : IPBinaryFloatDistance; + m_eDistFuncId = DistFuncId_e::IP_BINARY_GENERIC; #else int iBytes = ( uDim+7 ) >> 3; bool bUseSSE = iBytes>=16; @@ -1367,13 +1428,10 @@ IPSpaceBinaryFloat_c::IPSpaceBinaryFloat_c ( size_t uDim, bool bBuild ) case 7: m_fnDist = IPBinaryFloatDistance>; break; } - if ( !bBuild ) - { - if ( bUseSSE ) - m_eDistFuncId = bNeedResiduals ? DistFuncId_e::IP_BINARY_SIMD16_RESIDUALS : DistFuncId_e::IP_BINARY_SIMD16; - else - m_eDistFuncId = DistFuncId_e::IP_BINARY_GENERIC; - } + if ( bUseSSE ) + m_eDistFuncId = bNeedResiduals ? DistFuncId_e::IP_BINARY_SIMD16_RESIDUALS : DistFuncId_e::IP_BINARY_SIMD16; + else + m_eDistFuncId = DistFuncId_e::IP_BINARY_GENERIC; #endif } @@ -1391,13 +1449,10 @@ L2SpaceBinaryFloat_c::L2SpaceBinaryFloat_c ( size_t uDim, bool bBuild ) , m_tDistFuncParam ( uDim ) { #if defined(USE_SIMDE) - if ( bBuild ) - m_fnDist = L2BinaryFloatDistance; - else - { - m_fnDist = L2BinaryFloatDistance; - m_eDistFuncId = DistFuncId_e::L2_BINARY_GENERIC; - } + m_fnDist = bBuild + ? L2BinaryFloatDistance + : L2BinaryFloatDistance; + m_eDistFuncId = DistFuncId_e::L2_BINARY_GENERIC; #else int iBytes = ( uDim+7 ) >> 3; bool bUseSSE = iBytes>=16; @@ -1417,13 +1472,10 @@ L2SpaceBinaryFloat_c::L2SpaceBinaryFloat_c ( size_t uDim, bool bBuild ) case 7: m_fnDist = L2BinaryFloatDistance>; break; } - if ( !bBuild ) - { - if ( bUseSSE ) - m_eDistFuncId = bNeedResiduals ? DistFuncId_e::L2_BINARY_SIMD16_RESIDUALS : DistFuncId_e::L2_BINARY_SIMD16; - else - m_eDistFuncId = DistFuncId_e::L2_BINARY_GENERIC; - } + if ( bUseSSE ) + m_eDistFuncId = bNeedResiduals ? DistFuncId_e::L2_BINARY_SIMD16_RESIDUALS : DistFuncId_e::L2_BINARY_SIMD16; + else + m_eDistFuncId = DistFuncId_e::L2_BINARY_GENERIC; #endif } diff --git a/knn/space.h b/knn/space.h index 7885e383..6c0d6eb5 100644 --- a/knn/space.h +++ b/knn/space.h @@ -181,18 +181,28 @@ struct DistFuncParamBinary_t float IPFloatDistance ( const void * pVect1, const void * pVect2, size_t uRowID1, size_t uRowID2, const void * pParam ); void IPFloatDistanceBatch2 ( const void * pVect1, const void * pVect2A, const void * pVect2B, size_t uRowID1, size_t uRowID2A, size_t uRowID2B, const void * pParam, float & fDistA, float & fDistB ); float IPBinaryFloatDistanceGeneric ( const void * pVect1, const void * pVect2, size_t uRowID1, size_t uRowID2, const void * pParam ); +float IPBinaryFloatDistanceGenericBuild ( const void * pVect1, const void * pVect2, size_t uRowID1, size_t uRowID2, const void * pParam ); float L2FloatDistance ( const void * pVect1, const void * pVect2, size_t uRowID1, size_t uRowID2, const void * pParam ); void L2FloatDistanceBatch2 ( const void * pVect1, const void * pVect2A, const void * pVect2B, size_t uRowID1, size_t uRowID2A, size_t uRowID2B, const void * pParam, float & fDistA, float & fDistB ); float L2BinaryFloatDistanceGeneric ( const void * pVect1, const void * pVect2, size_t uRowID1, size_t uRowID2, const void * pParam ); +float L2BinaryFloatDistanceGenericBuild ( const void * pVect1, const void * pVect2, size_t uRowID1, size_t uRowID2, const void * pParam ); #if !defined(USE_SIMDE) float IPBinaryFloatDistanceSIMD16 ( const void * pVect1, const void * pVect2, size_t uRowID1, size_t uRowID2, const void * pParam ); +float IPBinaryFloatDistanceSIMD16Build ( const void * pVect1, const void * pVect2, size_t uRowID1, size_t uRowID2, const void * pParam ); float IPBinaryFloatDistanceSIMD16Residuals ( const void * pVect1, const void * pVect2, size_t uRowID1, size_t uRowID2, const void * pParam ); +float IPBinaryFloatDistanceSIMD16ResidualsBuild ( const void * pVect1, const void * pVect2, size_t uRowID1, size_t uRowID2, const void * pParam ); void IPBinaryFloatDistanceSIMD16Batch2 ( const void * pVect1, const void * pVect2A, const void * pVect2B, size_t uRowID1, size_t uRowID2A, size_t uRowID2B, const void * pParam, float & fDistA, float & fDistB ); +void IPBinaryFloatDistanceSIMD16Batch2Build ( const void * pVect1, const void * pVect2A, const void * pVect2B, size_t uRowID1, size_t uRowID2A, size_t uRowID2B, const void * pParam, float & fDistA, float & fDistB ); void IPBinaryFloatDistanceSIMD16ResidualsBatch2 ( const void * pVect1, const void * pVect2A, const void * pVect2B, size_t uRowID1, size_t uRowID2A, size_t uRowID2B, const void * pParam, float & fDistA, float & fDistB ); +void IPBinaryFloatDistanceSIMD16ResidualsBatch2Build ( const void * pVect1, const void * pVect2A, const void * pVect2B, size_t uRowID1, size_t uRowID2A, size_t uRowID2B, const void * pParam, float & fDistA, float & fDistB ); float L2BinaryFloatDistanceSIMD16 ( const void * pVect1, const void * pVect2, size_t uRowID1, size_t uRowID2, const void * pParam ); +float L2BinaryFloatDistanceSIMD16Build ( const void * pVect1, const void * pVect2, size_t uRowID1, size_t uRowID2, const void * pParam ); float L2BinaryFloatDistanceSIMD16Residuals ( const void * pVect1, const void * pVect2, size_t uRowID1, size_t uRowID2, const void * pParam ); +float L2BinaryFloatDistanceSIMD16ResidualsBuild ( const void * pVect1, const void * pVect2, size_t uRowID1, size_t uRowID2, const void * pParam ); void L2BinaryFloatDistanceSIMD16Batch2 ( const void * pVect1, const void * pVect2A, const void * pVect2B, size_t uRowID1, size_t uRowID2A, size_t uRowID2B, const void * pParam, float & fDistA, float & fDistB ); +void L2BinaryFloatDistanceSIMD16Batch2Build ( const void * pVect1, const void * pVect2A, const void * pVect2B, size_t uRowID1, size_t uRowID2A, size_t uRowID2B, const void * pParam, float & fDistA, float & fDistB ); void L2BinaryFloatDistanceSIMD16ResidualsBatch2 ( const void * pVect1, const void * pVect2A, const void * pVect2B, size_t uRowID1, size_t uRowID2A, size_t uRowID2B, const void * pParam, float & fDistA, float & fDistB ); +void L2BinaryFloatDistanceSIMD16ResidualsBatch2Build ( const void * pVect1, const void * pVect2A, const void * pVect2B, size_t uRowID1, size_t uRowID2A, size_t uRowID2B, const void * pParam, float & fDistA, float & fDistB ); #endif