Skip to content

Commit d9e3698

Browse files
committed
fix(knn): reapply sort after rescoring distances
Related issue #4320 KNN rescore now copies exact distances into the original knn_dist() sort key and reruns the original comparator, so explicit ORDER BY tie-breakers are respected for equal rescored distances.
1 parent bb1df6f commit d9e3698

8 files changed

Lines changed: 123 additions & 71 deletions

File tree

src/knnmisc.cpp

Lines changed: 47 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010

1111
#include <algorithm>
1212
#include "knnmisc.h"
13+
#include "sortsetup.h"
14+
#include "sortcomp.h"
1315
#include "knnlib.h"
1416
#include "exprtraits.h"
1517
#include "sphinxint.h"
@@ -841,24 +843,48 @@ RowIteratorsWithEstimates_t CreateKNNIterators ( knn::KNN_i * pKNN, const CSphQu
841843

842844
///////////////////////////////////////////////////////////////////////////////
843845

844-
struct MatchSortRescore_fn : CSphMatchComparatorState
846+
struct MatchSortRescore_fn
845847
{
846-
const CSphAttrLocator & m_tLocator;
848+
const ISphMatchComparator * m_pComp = nullptr;
849+
const CSphMatchComparatorState & m_tState;
847850

848-
MatchSortRescore_fn ( const CSphAttrLocator & tLoc ) : m_tLocator(tLoc) {}
851+
MatchSortRescore_fn ( const ISphMatchComparator * pComp, const CSphMatchComparatorState & tState )
852+
: m_pComp ( pComp )
853+
, m_tState ( tState )
854+
{
855+
assert ( m_pComp );
856+
}
849857

850858
bool IsLess ( const CSphMatch * a, const CSphMatch * b ) const
851859
{
852860
assert ( a && b );
853-
return a->GetAttrFloat(m_tLocator) < b->GetAttrFloat(m_tLocator);
861+
// CSphMatchComparatorState comparators report whether a match is worse.
862+
// sphSort() needs the opposite: whether a match must be emitted earlier.
863+
return m_pComp->VirtualIsLess ( *b, *a, m_tState );
854864
}
855865
};
856866

867+
static ISphMatchComparator * CreateMatchComparator ( ESphSortFunc eFunc )
868+
{
869+
switch ( eFunc )
870+
{
871+
case FUNC_REL_DESC: return new MatchRelevanceLt_fn();
872+
case FUNC_TIMESEGS: return new MatchTimeSegments_fn();
873+
case FUNC_GENERIC1: return new MatchGeneric1_fn();
874+
case FUNC_GENERIC2: return new MatchGeneric2_fn();
875+
case FUNC_GENERIC3: return new MatchGeneric3_fn();
876+
case FUNC_GENERIC4: return new MatchGeneric4_fn();
877+
case FUNC_GENERIC5: return new MatchGeneric5_fn();
878+
case FUNC_EXPR: return new MatchExpr_fn();
879+
default: return nullptr;
880+
}
881+
}
882+
857883

858884
class RescoreSorter_c : public ISphMatchSorter
859885
{
860886
public:
861-
RescoreSorter_c ( ISphMatchSorter * pSorter ) : m_pSorter ( pSorter ) {}
887+
RescoreSorter_c ( ISphMatchSorter * pSorter, CSphRefcountedPtr<ISphMatchComparator> pComp ) : m_pSorter ( pSorter ), m_pComp ( std::move ( pComp ) ) {}
862888

863889
bool Push ( const CSphMatch & tEntry ) final { return m_pSorter->Push(tEntry); }
864890
void Push ( const VecTraits_T<const CSphMatch> & dMatches ) override { for ( auto & i : dMatches ) m_pSorter->Push(i); }
@@ -897,6 +923,7 @@ class RescoreSorter_c : public ISphMatchSorter
897923

898924
private:
899925
std::unique_ptr<ISphMatchSorter> m_pSorter;
926+
CSphRefcountedPtr<ISphMatchComparator> m_pComp;
900927
};
901928

902929

@@ -921,14 +948,18 @@ int RescoreSorter_c::Flatten ( CSphMatch * pTo )
921948
auto * pKNNDistRescore = m_pSorter->GetSchema()->GetAttr ( GetKnnDistRescoreAttrName() );
922949
assert(pKNNDistRescore);
923950

924-
MatchSortRescore_fn tRescore ( pKNNDistRescore->m_tLocator );
925-
sphSort ( dMatches.Begin(), dMatches.GetLength(), tRescore, MatchSortAccessor_t() );
926-
927-
// copy rescored dist to old dist
951+
// Copy rescored dist to old dist first, then re-apply the original sorter.
952+
// The original sorter state starts with knn_dist() (unless the user explicitly
953+
// sorted by knn_dist()) and then contains user ORDER BY tie-breakers.
954+
// Sorting by rescored distance only, even stably, can keep stale approximate
955+
// distance order ahead of explicit tie-breakers for exact-distance ties.
928956
for ( auto & tMatch : dMatches )
929957
for ( const auto & tLocator : dOldKnnDistLoc )
930958
tMatch.SetAttrFloat ( tLocator, tMatch.GetAttrFloat ( pKNNDistRescore->m_tLocator ) );
931959

960+
MatchSortRescore_fn tRescore ( m_pComp, m_pSorter->GetState() );
961+
sphSort ( dMatches.Begin(), dMatches.GetLength(), tRescore, MatchSortAccessor_t() );
962+
932963
for ( auto & i : dMatches )
933964
Swap ( i, *pTo++ );
934965

@@ -938,7 +969,7 @@ int RescoreSorter_c::Flatten ( CSphMatch * pTo )
938969

939970
ISphMatchSorter * RescoreSorter_c::Clone() const
940971
{
941-
auto pClone = new RescoreSorter_c ( m_pSorter->Clone() );
972+
auto pClone = new RescoreSorter_c ( m_pSorter->Clone(), m_pComp );
942973
CloneTo(pClone);
943974
return pClone;
944975
}
@@ -952,12 +983,16 @@ void RescoreSorter_c::CloneTo ( ISphMatchSorter * pTrg ) const
952983
}
953984

954985

955-
ISphMatchSorter * CreateKNNRescoreSorter ( ISphMatchSorter * pSorter, const KnnSearchSettings_t & tSettings )
986+
ISphMatchSorter * CreateKNNRescoreSorter ( ISphMatchSorter * pSorter, const KnnSearchSettings_t & tSettings, ESphSortFunc eMatchFunc )
956987
{
957988
if ( tSettings.m_sAttr.IsEmpty() || !tSettings.m_bRescore )
958989
return pSorter;
959990

960-
return new RescoreSorter_c(pSorter);
991+
CSphRefcountedPtr<ISphMatchComparator> pComp ( CreateMatchComparator ( eMatchFunc ) );
992+
if ( !pComp )
993+
return nullptr;
994+
995+
return new RescoreSorter_c ( pSorter, std::move ( pComp ) );
961996
}
962997

963998
bool ValidateEmbeddingsAPITimeout ( const CSphString & sValue, int & iTimeout, CSphString & sError )

src/knnmisc.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ std::pair<RowidIterator_i *, bool> CreateKNNIterator ( knn::KNN_i * pKNN, const
7777
RowIteratorsWithEstimates_t CreateKNNIterators ( knn::KNN_i * pKNN, const CSphQuery & tQuery, const ISphSchema & tIndexSchema, const ISphSchema & tSorterSchema, knn::KNNFilter_i * pFilter, knn::HNSWTerminationPolicy_e ePolicy, QueryProfile_c * pProfile, bool & bError, CSphString & sError );
7878
std::unique_ptr<knn::KNNFilter_i> CreateKNNPrefilter ( const CSphQueryContext & tCtx, const CSphRowitem * pAttrPool, int iStride, int iDynamicSize, int64_t iFilterCount );
7979

80-
ISphMatchSorter * CreateKNNRescoreSorter ( ISphMatchSorter * pSorter, const KnnSearchSettings_t & tSettings );
80+
ISphMatchSorter * CreateKNNRescoreSorter ( ISphMatchSorter * pSorter, const KnnSearchSettings_t & tSettings, ESphSortFunc eMatchFunc );
8181

8282
const char * GetAPITimeoutErrorMsg();
8383
bool ValidateEmbeddingsAPITimeout ( const CSphString & sValue, int & iTimeout, CSphString & sError );

src/queuecreator.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2638,7 +2638,7 @@ ISphMatchSorter * QueueCreator_c::SpawnQueue()
26382638

26392639
if ( !m_tQuery.m_bHybridSearch )
26402640
{
2641-
pSorter = CreateKNNRescoreSorter ( pSorter, m_tQuery.HasKnn() ? m_tQuery.SingleKnnSettings() : KnnSearchSettings_t() );
2641+
pSorter = CreateKNNRescoreSorter ( pSorter, m_tQuery.HasKnn() ? m_tQuery.SingleKnnSettings() : KnnSearchSettings_t(), m_eMatchFunc );
26422642
if ( !pSorter )
26432643
return nullptr;
26442644
}

test/clt-tests/core/test-alter-rebuild-knn.rec

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -105,10 +105,10 @@ alter table t rebuild knn
105105
select *, knn_dist() dist from t where knn(v2, 10, (0.5,0.5,0.5)) order by dist asc, id asc
106106
--------------
107107
*************************** 1. row ***************************
108-
id: 3
109-
f:
110-
v1:
111-
v2: 1.000000,1.000000,1.000000
108+
id: 1
109+
f: abc
110+
v1: 0.000000,0.000000,0.000000
111+
v2: 0.000000,0.000000,0.000000
112112
@knn_dist: 0.750000
113113
dist: 0.750000
114114
*************************** 2. row ***************************
@@ -119,9 +119,9 @@ select *, knn_dist() dist from t where knn(v2, 10, (0.5,0.5,0.5)) order by dist
119119
@knn_dist: 0.750000
120120
dist: 0.750000
121121
*************************** 3. row ***************************
122-
id: 1
123-
f: abc
124-
v1: 0.000000,0.000000,0.000000
125-
v2: 0.000000,0.000000,0.000000
122+
id: 3
123+
f:
124+
v1:
125+
v2: 1.000000,1.000000,1.000000
126126
@knn_dist: 0.750000
127127
dist: 0.750000

0 commit comments

Comments
 (0)