Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 47 additions & 12 deletions src/knnmisc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

#include <algorithm>
#include "knnmisc.h"
#include "sortsetup.h"
#include "sortcomp.h"
#include "knnlib.h"
#include "exprtraits.h"
#include "sphinxint.h"
Expand Down Expand Up @@ -841,24 +843,48 @@ RowIteratorsWithEstimates_t CreateKNNIterators ( knn::KNN_i * pKNN, const CSphQu

///////////////////////////////////////////////////////////////////////////////

struct MatchSortRescore_fn : CSphMatchComparatorState
struct MatchSortRescore_fn
{
const CSphAttrLocator & m_tLocator;
const ISphMatchComparator * m_pComp = nullptr;
const CSphMatchComparatorState & m_tState;

MatchSortRescore_fn ( const CSphAttrLocator & tLoc ) : m_tLocator(tLoc) {}
MatchSortRescore_fn ( const ISphMatchComparator * pComp, const CSphMatchComparatorState & tState )
: m_pComp ( pComp )
, m_tState ( tState )
{
assert ( m_pComp );
}

bool IsLess ( const CSphMatch * a, const CSphMatch * b ) const
{
assert ( a && b );
return a->GetAttrFloat(m_tLocator) < b->GetAttrFloat(m_tLocator);
// CSphMatchComparatorState comparators report whether a match is worse.
// sphSort() needs the opposite: whether a match must be emitted earlier.
return m_pComp->VirtualIsLess ( *b, *a, m_tState );
}
};

static ISphMatchComparator * CreateMatchComparator ( ESphSortFunc eFunc )
{
switch ( eFunc )
{
case FUNC_REL_DESC: return new MatchRelevanceLt_fn();
case FUNC_TIMESEGS: return new MatchTimeSegments_fn();
case FUNC_GENERIC1: return new MatchGeneric1_fn();
case FUNC_GENERIC2: return new MatchGeneric2_fn();
case FUNC_GENERIC3: return new MatchGeneric3_fn();
case FUNC_GENERIC4: return new MatchGeneric4_fn();
case FUNC_GENERIC5: return new MatchGeneric5_fn();
case FUNC_EXPR: return new MatchExpr_fn();
default: return nullptr;
}
}


class RescoreSorter_c : public ISphMatchSorter
{
public:
RescoreSorter_c ( ISphMatchSorter * pSorter ) : m_pSorter ( pSorter ) {}
RescoreSorter_c ( ISphMatchSorter * pSorter, CSphRefcountedPtr<ISphMatchComparator> pComp ) : m_pSorter ( pSorter ), m_pComp ( std::move ( pComp ) ) {}

bool Push ( const CSphMatch & tEntry ) final { return m_pSorter->Push(tEntry); }
void Push ( const VecTraits_T<const CSphMatch> & dMatches ) override { for ( auto & i : dMatches ) m_pSorter->Push(i); }
Expand Down Expand Up @@ -897,6 +923,7 @@ class RescoreSorter_c : public ISphMatchSorter

private:
std::unique_ptr<ISphMatchSorter> m_pSorter;
CSphRefcountedPtr<ISphMatchComparator> m_pComp;
};


Expand All @@ -921,14 +948,18 @@ int RescoreSorter_c::Flatten ( CSphMatch * pTo )
auto * pKNNDistRescore = m_pSorter->GetSchema()->GetAttr ( GetKnnDistRescoreAttrName() );
assert(pKNNDistRescore);

MatchSortRescore_fn tRescore ( pKNNDistRescore->m_tLocator );
sphSort ( dMatches.Begin(), dMatches.GetLength(), tRescore, MatchSortAccessor_t() );

// copy rescored dist to old dist
// Copy rescored dist to old dist first, then re-apply the original sorter.
// The original sorter state starts with knn_dist() (unless the user explicitly
// sorted by knn_dist()) and then contains user ORDER BY tie-breakers.
// Sorting by rescored distance only, even stably, can keep stale approximate
// distance order ahead of explicit tie-breakers for exact-distance ties.
for ( auto & tMatch : dMatches )
for ( const auto & tLocator : dOldKnnDistLoc )
tMatch.SetAttrFloat ( tLocator, tMatch.GetAttrFloat ( pKNNDistRescore->m_tLocator ) );

MatchSortRescore_fn tRescore ( m_pComp, m_pSorter->GetState() );
sphSort ( dMatches.Begin(), dMatches.GetLength(), tRescore, MatchSortAccessor_t() );

for ( auto & i : dMatches )
Swap ( i, *pTo++ );

Expand All @@ -938,7 +969,7 @@ int RescoreSorter_c::Flatten ( CSphMatch * pTo )

ISphMatchSorter * RescoreSorter_c::Clone() const
{
auto pClone = new RescoreSorter_c ( m_pSorter->Clone() );
auto pClone = new RescoreSorter_c ( m_pSorter->Clone(), m_pComp );
CloneTo(pClone);
return pClone;
}
Expand All @@ -952,12 +983,16 @@ void RescoreSorter_c::CloneTo ( ISphMatchSorter * pTrg ) const
}


ISphMatchSorter * CreateKNNRescoreSorter ( ISphMatchSorter * pSorter, const KnnSearchSettings_t & tSettings )
ISphMatchSorter * CreateKNNRescoreSorter ( ISphMatchSorter * pSorter, const KnnSearchSettings_t & tSettings, ESphSortFunc eMatchFunc )
{
if ( tSettings.m_sAttr.IsEmpty() || !tSettings.m_bRescore )
return pSorter;

return new RescoreSorter_c(pSorter);
CSphRefcountedPtr<ISphMatchComparator> pComp ( CreateMatchComparator ( eMatchFunc ) );
if ( !pComp )
return nullptr;

return new RescoreSorter_c ( pSorter, std::move ( pComp ) );
}

bool ValidateEmbeddingsAPITimeout ( const CSphString & sValue, int & iTimeout, CSphString & sError )
Expand Down
2 changes: 1 addition & 1 deletion src/knnmisc.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ std::pair<RowidIterator_i *, bool> CreateKNNIterator ( knn::KNN_i * pKNN, const
RowIteratorsWithEstimates_t CreateKNNIterators ( knn::KNN_i * pKNN, const CSphQuery & tQuery, const ISphSchema & tIndexSchema, const ISphSchema & tSorterSchema, knn::KNNFilter_i * pFilter, knn::HNSWTerminationPolicy_e ePolicy, QueryProfile_c * pProfile, bool & bError, CSphString & sError );
std::unique_ptr<knn::KNNFilter_i> CreateKNNPrefilter ( const CSphQueryContext & tCtx, const CSphRowitem * pAttrPool, int iStride, int iDynamicSize, int64_t iFilterCount );

ISphMatchSorter * CreateKNNRescoreSorter ( ISphMatchSorter * pSorter, const KnnSearchSettings_t & tSettings );
ISphMatchSorter * CreateKNNRescoreSorter ( ISphMatchSorter * pSorter, const KnnSearchSettings_t & tSettings, ESphSortFunc eMatchFunc );

const char * GetAPITimeoutErrorMsg();
bool ValidateEmbeddingsAPITimeout ( const CSphString & sValue, int & iTimeout, CSphString & sError );
2 changes: 1 addition & 1 deletion src/queuecreator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2638,7 +2638,7 @@ ISphMatchSorter * QueueCreator_c::SpawnQueue()

if ( !m_tQuery.m_bHybridSearch )
{
pSorter = CreateKNNRescoreSorter ( pSorter, m_tQuery.HasKnn() ? m_tQuery.SingleKnnSettings() : KnnSearchSettings_t() );
pSorter = CreateKNNRescoreSorter ( pSorter, m_tQuery.HasKnn() ? m_tQuery.SingleKnnSettings() : KnnSearchSettings_t(), m_eMatchFunc );
if ( !pSorter )
return nullptr;
}
Expand Down
16 changes: 8 additions & 8 deletions test/clt-tests/core/test-alter-rebuild-knn.rec
Original file line number Diff line number Diff line change
Expand Up @@ -105,10 +105,10 @@ alter table t rebuild knn
select *, knn_dist() dist from t where knn(v2, 10, (0.5,0.5,0.5)) order by dist asc, id asc
--------------
*************************** 1. row ***************************
id: 3
f:
v1:
v2: 1.000000,1.000000,1.000000
id: 1
f: abc
v1: 0.000000,0.000000,0.000000
v2: 0.000000,0.000000,0.000000
@knn_dist: 0.750000
dist: 0.750000
*************************** 2. row ***************************
Expand All @@ -119,9 +119,9 @@ select *, knn_dist() dist from t where knn(v2, 10, (0.5,0.5,0.5)) order by dist
@knn_dist: 0.750000
dist: 0.750000
*************************** 3. row ***************************
id: 1
f: abc
v1: 0.000000,0.000000,0.000000
v2: 0.000000,0.000000,0.000000
id: 3
f:
v1:
v2: 1.000000,1.000000,1.000000
@knn_dist: 0.750000
dist: 0.750000
Loading
Loading