Skip to content

Commit 0b37507

Browse files
authored
fix: fixed a crash that could occur when running knn queries with 1-bit quantized vectors from multiple threads
1 parent 2291fd5 commit 0b37507

2 files changed

Lines changed: 41 additions & 41 deletions

File tree

knn/quantizer.cpp

Lines changed: 37 additions & 39 deletions
Original file line number | Diff line number | Diff line change
@@ -223,30 +223,27 @@ class BinaryQuantizer_c
223223
public:
224224
BinaryQuantizer_c ( int iDim, HNSWSimilarity_e eSimilarity );
225225

226-
void Quantize1Bit ( const Span_T<float> & dVector, const std::vector<float> & dCentroid, std::vector<uint8_t> & dResult );
227-
void Quantize4Bit ( const Span_T<float> & dVector, const std::vector<float> & dCentroid, std::vector<uint8_t> & dResult );
226+
void Quantize1Bit ( const Span_T<float> & dVector, const std::vector<float> & dCentroid, std::vector<uint8_t> & dResult ) const;
227+
void Quantize4Bit ( const Span_T<float> & dVector, const std::vector<float> & dCentroid, std::vector<uint8_t> & dResult ) const;
228228

229229
private:
230230
size_t m_uDim = 0;
231231
size_t m_uDimPadded = 0;
232232
HNSWSimilarity_e m_eSimilarity = HNSWSimilarity_e::COSINE;
233233
float m_fSqrtDim = 0.0f;
234234

235-
SpanResizeable_T<float> m_dVecMinusCentroid;
236-
SpanResizeable_T<uint8_t> m_dQuantized;
237-
238235
static void Pack ( const Span_T<float> & dVector, Span_T<uint8_t> & dPacked );
239-
FORCE_INLINE static int Quantize ( const Span_T<float> & dVector, float fMin, float fRange, SpanResizeable_T<uint8_t> & dQuantized );
236+
FORCE_INLINE static int Quantize ( const Span_T<float> & dVector, float fMin, float fRange, std::vector<uint8_t> & dQuantized );
240237
#if defined(USE_AVX2) || defined(USE_AVX512)
241238
FORCE_INLINE static void TransposeAVX ( const Span_T<uint8_t> & dQuantized, size_t uDim, Span_T<uint8_t> & dTransposed );
242239
#endif
243240
FORCE_INLINE static void Transpose ( const Span_T<uint8_t> & dQuantized, size_t uDim, Span_T<uint8_t> & dTransposed );
244241

245242
float ComputeQuality ( int iOriginalLength, const Span_T<float> & dVecMinusCentroidNormalized, const Span_T<uint8_t> & dPacked ) const;
246-
float QuantizeVecL2 ( const Span_T<float> & dVector, const std::vector<float> & dCentroid, Span_T<uint8_t> & dResult );
247-
Binary1BitFactorsIP_t QuantizeVecIP ( const Span_T<float> & dVector, const std::vector<float> & dCentroid, Span_T<uint8_t> & dResult );
243+
float QuantizeVecL2 ( const Span_T<float> & dVector, const std::vector<float> & dCentroid, Span_T<uint8_t> & dResult ) const;
244+
Binary1BitFactorsIP_t QuantizeVecIP ( const Span_T<float> & dVector, const std::vector<float> & dCentroid, Span_T<uint8_t> & dResult ) const;
248245

249-
template <typename T> FORCE_INLINE void PadToDim ( T & dVec )
246+
template <typename T> FORCE_INLINE void PadToDim ( T & dVec ) const
250247
{
251248
if ( dVec.size() < m_uDimPadded )
252249
dVec.resize ( m_uDimPadded, 0 );
@@ -303,7 +300,7 @@ void BinaryQuantizer_c::Pack ( const Span_T<float> & dVector, Span_T<uint8_t> &
303300
}
304301

305302

306-
int BinaryQuantizer_c::Quantize ( const Span_T<float> & dVector, float fMin, float fRange, SpanResizeable_T<uint8_t> & dQuantized )
303+
int BinaryQuantizer_c::Quantize ( const Span_T<float> & dVector, float fMin, float fRange, std::vector<uint8_t> & dQuantized )
307304
{
308305
dQuantized.resize ( dVector.size() );
309306

@@ -370,49 +367,49 @@ float BinaryQuantizer_c::ComputeQuality ( int iOriginalLength, const Span_T<floa
370367
}
371368

372369

373-
float BinaryQuantizer_c::QuantizeVecL2 ( const Span_T<float> & dVector, const std::vector<float> & dCentroid, Span_T<uint8_t> & dResult )
370+
float BinaryQuantizer_c::QuantizeVecL2 ( const Span_T<float> & dVector, const std::vector<float> & dCentroid, Span_T<uint8_t> & dResult ) const
374371
{
375-
m_dVecMinusCentroid.resize ( dVector.size() );
376-
for ( size_t i = 0; i < m_dVecMinusCentroid.size(); i++ )
377-
m_dVecMinusCentroid[i] = dVector[i] - dCentroid[i];
372+
std::vector<float> dVecMinusCentroid ( dVector.size() );
373+
for ( size_t i = 0; i < dVecMinusCentroid.size(); i++ )
374+
dVecMinusCentroid[i] = dVector[i] - dCentroid[i];
378375

379-
float fNorm = VecCalcNorm(m_dVecMinusCentroid);
380-
PadToDim(m_dVecMinusCentroid);
381-
Pack ( { m_dVecMinusCentroid.data(), dVector.size() }, dResult );
382-
m_dVecMinusCentroid.resize ( dVector.size() );
376+
float fNorm = VecCalcNorm(dVecMinusCentroid);
377+
PadToDim(dVecMinusCentroid);
378+
Pack ( { dVecMinusCentroid.data(), dVector.size() }, dResult );
379+
dVecMinusCentroid.resize ( dVector.size() );
383380

384-
for ( float & i : m_dVecMinusCentroid )
381+
for ( float & i : dVecMinusCentroid )
385382
i = std::abs(i) / m_fSqrtDim;
386383

387-
float fNormalized = std::accumulate ( m_dVecMinusCentroid.begin (), m_dVecMinusCentroid.end (), 0.0f );
384+
float fNormalized = std::accumulate ( dVecMinusCentroid.begin (), dVecMinusCentroid.end (), 0.0f );
388385
fNormalized /= fNorm;
389386
return std::isfinite(fNormalized) ? fNormalized : 0.8f;
390387
}
391388

392389

393-
Binary1BitFactorsIP_t BinaryQuantizer_c::QuantizeVecIP ( const Span_T<float> & dVector, const std::vector<float> & dCentroid, Span_T<uint8_t> & dResult )
390+
Binary1BitFactorsIP_t BinaryQuantizer_c::QuantizeVecIP ( const Span_T<float> & dVector, const std::vector<float> & dCentroid, Span_T<uint8_t> & dResult ) const
394391
{
395392
float fVecDotCentroid = 0.0f;
396-
m_dVecMinusCentroid.resize ( dVector.size() );
393+
std::vector<float> dVecMinusCentroid ( dVector.size() );
397394
for ( size_t i = 0; i < dVector.size(); i++ )
398395
{
399396
fVecDotCentroid += dVector[i]*dCentroid[i];
400-
m_dVecMinusCentroid[i] = dVector[i] - dCentroid[i];
397+
dVecMinusCentroid[i] = dVector[i] - dCentroid[i];
401398
}
402399

403-
float fVecMinusCentroidNorm = VecCalcNorm(m_dVecMinusCentroid);
404-
PadToDim(m_dVecMinusCentroid);
405-
Pack ( { m_dVecMinusCentroid.data(), dVector.size() }, dResult );
400+
float fVecMinusCentroidNorm = VecCalcNorm(dVecMinusCentroid);
401+
PadToDim(dVecMinusCentroid);
402+
Pack ( { dVecMinusCentroid.data(), dVector.size() }, dResult );
406403

407-
for ( float & i : m_dVecMinusCentroid )
404+
for ( float & i : dVecMinusCentroid )
408405
i /= fVecMinusCentroidNorm;
409406

410-
float fQuality = ComputeQuality ( dVector.size(), m_dVecMinusCentroid, dResult );
407+
float fQuality = ComputeQuality ( dVector.size(), dVecMinusCentroid, dResult );
411408
return { fQuality, fVecMinusCentroidNorm, fVecDotCentroid, (float)PopCnt(dResult) };
412409
}
413410

414411

415-
void BinaryQuantizer_c::Quantize1Bit ( const Span_T<float> & dVector, const std::vector<float> & dCentroid, std::vector<uint8_t> & dResult )
412+
void BinaryQuantizer_c::Quantize1Bit ( const Span_T<float> & dVector, const std::vector<float> & dCentroid, std::vector<uint8_t> & dResult ) const
416413
{
417414
size_t uDataSize = ( ( dVector.size()+7 ) >> 3 );
418415
size_t uHeaderSize = m_eSimilarity==HNSWSimilarity_e::L2 ? sizeof(Binary1BitFactorsL2_t) : sizeof(Binary1BitFactorsIP_t);
@@ -588,33 +585,34 @@ void BinaryQuantizer_c::Transpose ( const Span_T<uint8_t> & dQuantized, size_t u
588585
}
589586

590587

591-
void BinaryQuantizer_c::Quantize4Bit ( const Span_T<float> & dVector, const std::vector<float> & dCentroid, std::vector<uint8_t> & dResult )
588+
void BinaryQuantizer_c::Quantize4Bit ( const Span_T<float> & dVector, const std::vector<float> & dCentroid, std::vector<uint8_t> & dResult ) const
592589
{
593590
assert ( dVector.size()==dCentroid.size() );
594591

595-
m_dVecMinusCentroid.resize ( dVector.size() );
592+
std::vector<float> dVecMinusCentroid ( dVector.size() );
596593

597594
Binary4BitFactors_t tFactors = { 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f };
598595

599596
for ( size_t i = 0; i < dVector.size(); i++ )
600597
{
601598
float fDiff = dVector[i] - dCentroid[i];
602599
tFactors.m_fDistanceToCentroidSq += fDiff*fDiff;
603-
m_dVecMinusCentroid[i] = fDiff;
600+
dVecMinusCentroid[i] = fDiff;
604601
}
605602

606603
if ( m_eSimilarity!=HNSWSimilarity_e::L2 )
607604
{
608-
tFactors.m_fVecMinusCentroidNorm = VecNormalize(m_dVecMinusCentroid);
605+
tFactors.m_fVecMinusCentroidNorm = VecNormalize(dVecMinusCentroid);
609606
tFactors.m_fVecDotCentroid = VecDot ( dVector, dCentroid );
610607
}
611608

612609
float fMax;
613-
VecMinMax ( m_dVecMinusCentroid, tFactors.m_fMin, fMax );
610+
VecMinMax ( dVecMinusCentroid, tFactors.m_fMin, fMax );
614611
tFactors.m_fRange = ( fMax - tFactors.m_fMin ) / 15.0f;
615612

616-
tFactors.m_fQuantizedSum = (float)Quantize ( m_dVecMinusCentroid, tFactors.m_fMin, tFactors.m_fRange, m_dQuantized );
617-
PadToDim(m_dQuantized);
613+
std::vector<uint8_t> dQuantized;
614+
tFactors.m_fQuantizedSum = (float)Quantize ( dVecMinusCentroid, tFactors.m_fMin, tFactors.m_fRange, dQuantized );
615+
PadToDim(dQuantized);
618616

619617
size_t uDataSize = dVector.size() >> 1;
620618
size_t uHeaderSize = sizeof(float)*6;
@@ -626,13 +624,13 @@ void BinaryQuantizer_c::Quantize4Bit ( const Span_T<float> & dVector, const std:
626624
Span_T<uint8_t> dTransposed ( (uint8_t*)pHeader, uDataSize );
627625

628626
if ( uDataSize & 15 )
629-
Transpose ( m_dQuantized, m_uDim, dTransposed );
627+
Transpose ( dQuantized, m_uDim, dTransposed );
630628
else
631629
{
632630
#if defined(USE_AVX2) || defined(USE_AVX512)
633-
TransposeAVX ( m_dQuantized, m_uDim, dTransposed );
631+
TransposeAVX ( dQuantized, m_uDim, dTransposed );
634632
#else
635-
Transpose ( m_dQuantized, m_uDim, dTransposed );
633+
Transpose ( dQuantized, m_uDim, dTransposed );
636634
#endif
637635
}
638636
}

util/util_private.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -385,7 +385,8 @@ int CalcNumBits ( uint64_t uNumber );
385385
bool CopySingleFile ( const std::string & sSource, const std::string & sDest, std::string & sError, int iMode, size_t tBufferSize=1048576 );
386386
bool FloatEqual ( float fA, float fB );
387387

388-
FORCE_INLINE float VecCalcNorm ( const Span_T<float> & dData )
388+
template <typename T>
389+
FORCE_INLINE float VecCalcNorm ( const T & dData )
389390
{
390391
size_t uSize = dData.size();
391392
size_t i = 0;
@@ -417,7 +418,8 @@ FORCE_INLINE float VecCalcNorm ( const Span_T<float> & dData )
417418
return sqrtf(fNorm);
418419
}
419420

420-
FORCE_INLINE float VecNormalize ( Span_T<float> & dData )
421+
template <typename T>
422+
FORCE_INLINE float VecNormalize ( T & dData )
421423
{
422424
float fNorm = VecCalcNorm(dData);
423425
float fDiv = 1.0f / (fNorm + 1e-30f);

0 commit comments

Comments
 (0)