Skip to content

Commit 522b26c

Browse files
committed
Join post-HNSW block rescore for diversifying child KNN (POC)
Optional blockRescore on DiversifyingChildren float/byte KNN; shared blockRescore() with visited accounting; tests; JMH benchmark; CHANGES (Improvements, GITHUB#15839). Relates to #15839
1 parent 29af851 commit 522b26c

7 files changed

Lines changed: 574 additions & 6 deletions

File tree

lucene/CHANGES.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,11 @@ New Features
108108

109109
Improvements
110110
---------------------
111+
* GITHUB#15839: DiversifyingChildren KNN queries now support optional post-HNSW block rescoring:
112+
when enabled, all children in each found parent's block are scored after approximate search,
113+
guaranteeing the best child per parent is returned and correctly tracking extra visited nodes.
114+
(Prithvi S)
115+
111116
* GITHUB#15704: Replace LinkedList with more efficient data structure. (Renato Haeberli)
112117

113118
* GITHUB#15682: Use ArrayDeque instead of LinkedList in CompoundWordTokenFilterBase.java. (Renato Haeberli)

lucene/benchmark-jmh/build.gradle

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ description = 'Lucene JMH micro-benchmarking module'
2020
dependencies {
2121
moduleImplementation project(':lucene:core')
2222
moduleImplementation project(':lucene:expressions')
23+
moduleImplementation project(':lucene:join')
2324
moduleImplementation project(':lucene:sandbox')
2425
moduleTestImplementation project(':lucene:test-framework')
2526

lucene/benchmark-jmh/src/java/module-info.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
requires jdk.unsupported;
2525
requires org.apache.lucene.core;
2626
requires org.apache.lucene.expressions;
27+
requires org.apache.lucene.join;
2728
requires org.apache.lucene.sandbox;
2829
requires commons.math3;
2930

Lines changed: 216 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,216 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0 (the
6+
* "License"); you may not use this file except in compliance with the
7+
* License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
package org.apache.lucene.benchmark.jmh;
18+
19+
import java.io.IOException;
20+
import java.util.ArrayList;
21+
import java.util.List;
22+
import java.util.Random;
23+
import java.util.concurrent.TimeUnit;
24+
import org.apache.lucene.document.Document;
25+
import org.apache.lucene.document.Field;
26+
import org.apache.lucene.document.KnnFloatVectorField;
27+
import org.apache.lucene.document.StringField;
28+
import org.apache.lucene.index.DirectoryReader;
29+
import org.apache.lucene.index.IndexWriter;
30+
import org.apache.lucene.index.IndexWriterConfig;
31+
import org.apache.lucene.index.Term;
32+
import org.apache.lucene.index.VectorSimilarityFunction;
33+
import org.apache.lucene.search.IndexSearcher;
34+
import org.apache.lucene.search.Query;
35+
import org.apache.lucene.search.TermQuery;
36+
import org.apache.lucene.search.TopDocs;
37+
import org.apache.lucene.search.join.BitSetProducer;
38+
import org.apache.lucene.search.join.DiversifyingChildrenFloatKnnVectorQuery;
39+
import org.apache.lucene.search.join.QueryBitSetProducer;
40+
import org.apache.lucene.store.ByteBuffersDirectory;
41+
import org.apache.lucene.store.Directory;
42+
import org.apache.lucene.util.VectorUtil;
43+
import org.openjdk.jmh.annotations.Benchmark;
44+
import org.openjdk.jmh.annotations.BenchmarkMode;
45+
import org.openjdk.jmh.annotations.Fork;
46+
import org.openjdk.jmh.annotations.Level;
47+
import org.openjdk.jmh.annotations.Measurement;
48+
import org.openjdk.jmh.annotations.Mode;
49+
import org.openjdk.jmh.annotations.OutputTimeUnit;
50+
import org.openjdk.jmh.annotations.Param;
51+
import org.openjdk.jmh.annotations.Scope;
52+
import org.openjdk.jmh.annotations.Setup;
53+
import org.openjdk.jmh.annotations.State;
54+
import org.openjdk.jmh.annotations.TearDown;
55+
import org.openjdk.jmh.annotations.Warmup;
56+
import org.openjdk.jmh.infra.Blackhole;
57+
58+
/**
59+
* End-to-end {@link DiversifyingChildrenFloatKnnVectorQuery} search on a single-segment block-join
60+
* index (children + parent marker per block), using the default HNSW approximate path ({@code
61+
* childFilter == null}).
62+
*
63+
* <p>The {@code blockRescore} parameter switches the feature on/off so both modes can be compared
64+
* in a single run (see <a href="https://github.com/apache/lucene/issues/15839">LUCENE-15839</a>).
65+
* Extra work scales roughly with {@code topK * childrenPerParent}.
66+
*
67+
* <p>Indicative results on Apple M-series, JDK 25, dim=96, topK=64, 4096 parent blocks (lower is
68+
* better):
69+
*
70+
* <pre>
71+
* blockRescore childrenPerParent Score (ms/op)
72+
* false 8 0.123
73+
* false 32 0.226
74+
* false 64 0.254
75+
* true 8 0.151 (+23%)
76+
* true 32 0.316 (+40%)
77+
* true 64 0.412 (+62%)
78+
* </pre>
79+
*
80+
* <p>Example:
81+
*
82+
* <pre>{@code
83+
* ./gradlew :lucene:benchmark-jmh:assemble
84+
* cd lucene/benchmark-jmh/build/benchmarks
85+
* java -jar lucene-benchmark-jmh-*-SNAPSHOT.jar DiversifyingChildrenFloatKnnJoin \\
86+
* -f 2 -wi 3 -i 8 -tu ms
87+
* }</pre>
88+
*/
89+
@BenchmarkMode(Mode.AverageTime)
90+
@OutputTimeUnit(TimeUnit.MILLISECONDS)
91+
@State(Scope.Benchmark)
92+
@Warmup(iterations = 3, time = 1)
93+
@Measurement(iterations = 6, time = 1)
94+
@Fork(
95+
value = 2,
96+
jvmArgsAppend = {
97+
"-Xmx2g",
98+
"-Xms2g",
99+
"-XX:+AlwaysPreTouch",
100+
"--add-modules=jdk.incubator.vector"
101+
})
102+
public class DiversifyingChildrenFloatKnnJoinBenchmark {
103+
104+
/** Approximate neighbors per diversified parent bucket. */
105+
@Param({"64"})
106+
public int topK;
107+
108+
/**
109+
* Children with vectors per parent block. Post-HNSW block rescoring iterates sibling children in
110+
* each retained block, so incremental cost rises with this parameter.
111+
*/
112+
@Param({"8", "32", "64"})
113+
public int childrenPerParent;
114+
115+
@Param({"96"})
116+
public int dimension;
117+
118+
/**
119+
* Whether to enable post-HNSW block rescoring. When {@code true}, after HNSW search all children
120+
* in each found parent's block are scored to guarantee the best child is returned. Compare
121+
* {@code false} (baseline / no rescoring) against {@code true} (rescoring enabled) to measure
122+
* latency overhead.
123+
*/
124+
@Param({"false", "true"})
125+
public boolean blockRescore;
126+
127+
private Directory directory;
128+
private IndexSearcher searcher;
129+
private Query diversifyingJoinQuery;
130+
131+
static Document parentDoc() {
132+
Document d = new Document();
133+
d.add(new StringField("docType", "_parent", Field.Store.NO));
134+
return d;
135+
}
136+
137+
/** Fixed corpus size for stable HNSW behavior; must be >= topK. */
138+
private static final int NUM_PARENT_BLOCKS = 4096;
139+
140+
private static float[] randomUnitVector(Random random, int dim, float[] scratch) {
141+
for (int i = 0; i < dim; i++) {
142+
scratch[i] = random.nextFloat() * 2f - 1f;
143+
}
144+
return VectorUtil.l2normalize(scratch, false);
145+
}
146+
147+
@Setup(Level.Trial)
148+
public void setupTrial() throws IOException {
149+
if (topK > NUM_PARENT_BLOCKS) {
150+
throw new IllegalStateException("topK must be <= NUM_PARENT_BLOCKS");
151+
}
152+
directory = new ByteBuffersDirectory();
153+
IndexWriterConfig iwc = new IndexWriterConfig();
154+
long randomSeed = 0xC0FFEE42F00DL ^ ((long) childrenPerParent << 32) ^ dimension;
155+
Random random = new Random(randomSeed);
156+
float[] scratch = new float[dimension];
157+
try (IndexWriter w = new IndexWriter(directory, iwc)) {
158+
for (int p = 0; p < NUM_PARENT_BLOCKS; p++) {
159+
List<Document> block = new ArrayList<>(childrenPerParent + 1);
160+
for (int c = 0; c < childrenPerParent; c++) {
161+
Document child = new Document();
162+
child.add(
163+
new KnnFloatVectorField(
164+
"vec",
165+
randomUnitVector(random, dimension, scratch),
166+
VectorSimilarityFunction.DOT_PRODUCT));
167+
block.add(child);
168+
}
169+
block.add(parentDoc());
170+
w.addDocuments(block);
171+
}
172+
w.forceMerge(1);
173+
}
174+
175+
var reader = DirectoryReader.open(directory);
176+
searcher = new IndexSearcher(reader);
177+
BitSetProducer parentsFilter =
178+
new QueryBitSetProducer(new TermQuery(new Term("docType", "_parent")));
179+
float[] queryVector = new float[dimension];
180+
queryVector[0] = 1f;
181+
for (int i = 1; i < dimension; i++) {
182+
queryVector[i] = 0f;
183+
}
184+
VectorUtil.l2normalize(queryVector, false);
185+
diversifyingJoinQuery =
186+
new DiversifyingChildrenFloatKnnVectorQuery(
187+
"vec",
188+
queryVector,
189+
null,
190+
topK,
191+
parentsFilter,
192+
org.apache.lucene.search.knn.KnnSearchStrategy.Hnsw.DEFAULT,
193+
blockRescore);
194+
}
195+
196+
@TearDown(Level.Trial)
197+
public void tearDownTrial() throws IOException {
198+
if (searcher != null) {
199+
searcher.getIndexReader().close();
200+
}
201+
if (directory != null) {
202+
directory.close();
203+
}
204+
}
205+
206+
@Benchmark
207+
public void searchDiversifyingJoinHnsw(Blackhole bh) throws IOException {
208+
TopDocs hits = searcher.search(diversifyingJoinQuery, topK);
209+
bh.consume(hits.scoreDocs.length);
210+
bh.consume(hits.totalHits.value());
211+
if (hits.scoreDocs.length > 0) {
212+
bh.consume(hits.scoreDocs[0].doc);
213+
bh.consume(hits.scoreDocs[0].score);
214+
}
215+
}
216+
}

lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenByteKnnVectorQuery.java

Lines changed: 52 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ public class DiversifyingChildrenByteKnnVectorQuery extends KnnByteVectorQuery {
6060
private final Query childFilter;
6161
private final int k;
6262
private final byte[] query;
63+
private final boolean blockRescore;
6364

6465
/**
6566
* Create a ToParentBlockJoinByteVectorQuery.
@@ -72,7 +73,7 @@ public class DiversifyingChildrenByteKnnVectorQuery extends KnnByteVectorQuery {
7273
*/
7374
public DiversifyingChildrenByteKnnVectorQuery(
7475
String field, byte[] query, Query childFilter, int k, BitSetProducer parentsFilter) {
75-
this(field, query, childFilter, k, parentsFilter, DEFAULT);
76+
this(field, query, childFilter, k, parentsFilter, DEFAULT, false);
7677
}
7778

7879
/**
@@ -95,11 +96,40 @@ public DiversifyingChildrenByteKnnVectorQuery(
9596
int k,
9697
BitSetProducer parentsFilter,
9798
KnnSearchStrategy searchStrategy) {
99+
this(field, query, childFilter, k, parentsFilter, searchStrategy, false);
100+
}
101+
102+
/**
103+
* Create a DiversifyingChildrenByteKnnVectorQuery with optional post-HNSW block rescoring.
104+
*
105+
* <p>When {@code blockRescore} is {@code true}, after the approximate HNSW search completes, all
106+
* children in each found parent's block are scored to guarantee the truly best child is returned.
107+
* See {@link DiversifyingChildrenFloatKnnVectorQuery#DiversifyingChildrenFloatKnnVectorQuery(
108+
* String, float[], Query, int, BitSetProducer, KnnSearchStrategy, boolean)} for details.
109+
*
110+
* @param field the query field
111+
* @param query the vector query
112+
* @param childFilter the child filter
113+
* @param k how many parent documents to return given the matching children
114+
* @param parentsFilter Filter identifying the parent documents.
115+
* @param searchStrategy the search strategy to use.
116+
* @param blockRescore if {@code true}, enables post-HNSW block rescoring.
117+
* @lucene.experimental
118+
*/
119+
public DiversifyingChildrenByteKnnVectorQuery(
120+
String field,
121+
byte[] query,
122+
Query childFilter,
123+
int k,
124+
BitSetProducer parentsFilter,
125+
KnnSearchStrategy searchStrategy,
126+
boolean blockRescore) {
98127
super(field, query, k, childFilter, searchStrategy);
99128
this.childFilter = childFilter;
100129
this.parentsFilter = parentsFilter;
101130
this.k = k;
102131
this.query = query;
132+
this.blockRescore = blockRescore;
103133
}
104134

105135
@Override
@@ -173,7 +203,25 @@ protected TopDocs approximateSearch(
173203
return NO_RESULTS;
174204
}
175205
context.reader().searchNearestVectors(field, query, collector, acceptDocs);
176-
return collector.topDocs();
206+
TopDocs results = collector.topDocs();
207+
if (!blockRescore || results.scoreDocs.length == 0) {
208+
return results;
209+
}
210+
BitSet parentBitSet = parentsFilter.getBitSet(context);
211+
if (parentBitSet == null) {
212+
return results;
213+
}
214+
ByteVectorValues vectorValues = context.reader().getByteVectorValues(field);
215+
if (vectorValues == null) {
216+
return results;
217+
}
218+
VectorScorer scorer = vectorValues.scorer(query);
219+
if (scorer == null) {
220+
return results;
221+
}
222+
// Delegate to the shared static implementation in the float variant.
223+
return DiversifyingChildrenFloatKnnVectorQuery.blockRescore(
224+
results, acceptDocs, parentBitSet, scorer);
177225
}
178226

179227
@Override
@@ -195,14 +243,15 @@ public boolean equals(Object o) {
195243
if (!super.equals(o)) return false;
196244
DiversifyingChildrenByteKnnVectorQuery that = (DiversifyingChildrenByteKnnVectorQuery) o;
197245
return k == that.k
246+
&& blockRescore == that.blockRescore
198247
&& Objects.equals(parentsFilter, that.parentsFilter)
199248
&& Objects.equals(childFilter, that.childFilter)
200249
&& Arrays.equals(query, that.query);
201250
}
202251

203252
@Override
204253
public int hashCode() {
205-
int result = Objects.hash(super.hashCode(), parentsFilter, childFilter, k);
254+
int result = Objects.hash(super.hashCode(), parentsFilter, childFilter, k, blockRescore);
206255
result = 31 * result + Arrays.hashCode(query);
207256
return result;
208257
}

0 commit comments

Comments
 (0)