2020
2121import java .io .IOException ;
2222import java .util .Arrays ;
23- import java .util .Comparator ;
2423import java .util .Objects ;
2524import org .apache .lucene .index .FloatVectorValues ;
2625import org .apache .lucene .index .LeafReaderContext ;
@@ -62,7 +61,7 @@ public class DiversifyingChildrenFloatKnnVectorQuery extends KnnFloatVectorQuery
6261 private final Query childFilter ;
6362 private final int k ;
6463 private final float [] query ;
65- private final boolean blockRescore ;
64+ private final boolean rescoreBlocks ;
6665
6766 /**
6867 * Create a DiversifyingChildrenFloatKnnVectorQuery.
@@ -104,19 +103,23 @@ public DiversifyingChildrenFloatKnnVectorQuery(
104103 /**
105104 * Create a DiversifyingChildrenFloatKnnVectorQuery with optional post-HNSW block rescoring.
106105 *
107- * <p>When {@code blockRescore} is {@code true}, after the approximate HNSW search completes, all
108- * children in each found parent's block are scored to guarantee the truly best child is returned —
109- * not merely the sibling the graph traversal happened to reach first. This adds O(k ×
110- * childrenPerParent) extra scoring work; enable it when block sizes are small or result quality is
111- * more important than latency.
106+ * <p>When {@code rescoreBlocks} is {@code true}, after the approximate HNSW search completes, all
107+ * children in each found parent's block are scored to guarantee the truly best child is returned
108+ * — not merely the sibling the graph traversal happened to reach first. This adds O(k ×
109+ * childrenPerParent) extra scoring work; enable it when block sizes are small or result quality
110+ * is more important than latency.
111+ *
112+ * <p>This applies only to the approximate (HNSW) search path. When the index is small enough that
113+ * Lucene falls back to exact search, all children are already scored exhaustively and no
114+ * additional rescoring is performed.
112115 *
113116 * @param field the query field
114117 * @param query the vector query
115118 * @param childFilter the child filter
116119 * @param k how many parent documents to return given the matching children
117120 * @param parentsFilter Filter identifying the parent documents.
118121 * @param searchStrategy the search strategy to use.
119- * @param blockRescore if {@code true}, enables post-HNSW block rescoring.
122+ * @param rescoreBlocks if {@code true}, enables post-HNSW block rescoring.
120123 * @lucene.experimental
121124 */
122125 public DiversifyingChildrenFloatKnnVectorQuery (
@@ -126,13 +129,13 @@ public DiversifyingChildrenFloatKnnVectorQuery(
126129 int k ,
127130 BitSetProducer parentsFilter ,
128131 KnnSearchStrategy searchStrategy ,
129- boolean blockRescore ) {
132+ boolean rescoreBlocks ) {
130133 super (field , query , k , childFilter , searchStrategy );
131134 this .childFilter = childFilter ;
132135 this .parentsFilter = parentsFilter ;
133136 this .k = k ;
134137 this .query = query ;
135- this .blockRescore = blockRescore ;
138+ this .rescoreBlocks = rescoreBlocks ;
136139 }
137140
138141 @ Override
@@ -206,7 +209,7 @@ protected TopDocs approximateSearch(
206209 }
207210 context .reader ().searchNearestVectors (field , query , collector , acceptDocs );
208211 TopDocs results = collector .topDocs ();
209- if (!blockRescore || results .scoreDocs .length == 0 ) {
212+ if (!rescoreBlocks || results .scoreDocs .length == 0 ) {
210213 return results ;
211214 }
212215 BitSet parentBitSet = parentsFilter .getBitSet (context );
@@ -226,13 +229,10 @@ protected TopDocs approximateSearch(
226229
227230 /**
228231 * For each parent already found by approximate search, scores all children in that parent's block
229- * to ensure the truly best child is returned — not merely the sibling the graph traversal happened
230- * to reach first. Children are processed in ascending docId order so the sequential {@link
231- * VectorScorer} only advances forward. Extra nodes scored are added to {@link
232+ * to ensure the truly best child is returned — not merely the sibling the graph traversal
233+ * happened to reach first. Children are processed in ascending docId order so the sequential
234+ * {@link VectorScorer} only advances forward. Extra nodes scored are added to {@link
232235 * TotalHits#value()}.
233- *
234- * <p>This method is package-private so that {@link DiversifyingChildrenByteKnnVectorQuery} can
235- * reuse the same implementation rather than duplicating it.
236236 */
237237 static TopDocs blockRescore (
238238 TopDocs results , AcceptDocs acceptDocs , BitSet parentBitSet , VectorScorer scorer )
@@ -243,7 +243,7 @@ static TopDocs blockRescore(
243243 // Sort by docId so parent blocks are visited in ascending order — the forward-only
244244 // VectorScorer cannot go backwards.
245245 ScoreDoc [] scoreDocs = results .scoreDocs .clone ();
246- Arrays .sort (scoreDocs , Comparator . comparingInt ( sd -> sd .doc ));
246+ Arrays .sort (scoreDocs , ( a , b ) -> Integer . compare ( a . doc , b .doc ));
247247
248248 long extraVisited = 0 ;
249249 for (ScoreDoc scoreDoc : scoreDocs ) {
@@ -256,10 +256,11 @@ static TopDocs blockRescore(
256256 continue ;
257257 }
258258 if (scorerIter .advance (child ) == child ) {
259- // Don't double-count the child HNSW already visited.
260- if ( child != hnswBestChild ) {
261- extraVisited ++ ;
259+ if ( child == hnswBestChild ) {
260+ // Advance past the child HNSW already scored; no need to re-compute.
261+ continue ;
262262 }
263+ extraVisited ++;
263264 float s = scorer .score ();
264265 if (s > scoreDoc .score ) {
265266 scoreDoc .score = s ;
@@ -293,15 +294,15 @@ public boolean equals(Object o) {
293294 if (!super .equals (o )) return false ;
294295 DiversifyingChildrenFloatKnnVectorQuery that = (DiversifyingChildrenFloatKnnVectorQuery ) o ;
295296 return k == that .k
296- && blockRescore == that .blockRescore
297+ && rescoreBlocks == that .rescoreBlocks
297298 && Objects .equals (parentsFilter , that .parentsFilter )
298299 && Objects .equals (childFilter , that .childFilter )
299300 && Arrays .equals (query , that .query );
300301 }
301302
302303 @ Override
303304 public int hashCode () {
304- int result = Objects .hash (super .hashCode (), parentsFilter , childFilter , k , blockRescore );
305+ int result = Objects .hash (super .hashCode (), parentsFilter , childFilter , k , rescoreBlocks );
305306 result = 31 * result + Arrays .hashCode (query );
306307 return result ;
307308 }
0 commit comments