diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index eeae8f559f4b..70b6fd91d7c2 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -106,6 +106,10 @@ New Features * GITHUB#15818: Add BM25 k3 query-term frequency saturation to BM25Similarity. (Sagar Upadhyaya) +* GITHUB#16051: Add ArrayTermInSetQuery to sandbox, an alternative TermInSetQuery + storing terms in a sorted BytesRef[] over a packed byte[] for faster per-segment + iteration and a vectorized equals/hashCode fast path. (Govind Balaji S) + Improvements --------------------- * GITHUB#15704: Replace LinkedList with more efficient data structure. (Renato Haeberli) diff --git a/lucene/benchmark-jmh/src/java/org/apache/lucene/benchmark/jmh/ArrayTermInSetQueryBenchmark.java b/lucene/benchmark-jmh/src/java/org/apache/lucene/benchmark/jmh/ArrayTermInSetQueryBenchmark.java new file mode 100644 index 000000000000..4174ebdfcfd1 --- /dev/null +++ b/lucene/benchmark-jmh/src/java/org/apache/lucene/benchmark/jmh/ArrayTermInSetQueryBenchmark.java @@ -0,0 +1,362 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.lucene.benchmark.jmh; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Locale; +import java.util.Random; +import java.util.TreeSet; +import java.util.concurrent.TimeUnit; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.Field; +import org.apache.lucene.document.StringField; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.index.NoMergePolicy; +import org.apache.lucene.sandbox.search.ArrayTermInSetQuery; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.TermInSetQuery; +import org.apache.lucene.store.ByteBuffersDirectory; +import org.apache.lucene.util.BytesRef; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Level; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.infra.Blackhole; + +/** + * Compares {@link TermInSetQuery} (PrefixCodedTerms-backed) against {@link ArrayTermInSetQuery} + * (sorted-{@code BytesRef[]}-backed) along these dimensions: + * + * + * + *

Parameters: + * + *

+ */ +@BenchmarkMode(Mode.AverageTime) +@OutputTimeUnit(TimeUnit.MICROSECONDS) +@Warmup(iterations = 3, time = 5) +@Measurement(iterations = 5, time = 5) +@Fork(1) +public class ArrayTermInSetQueryBenchmark { + + private static final String FIELD = "allowed_groups"; + + enum IndexContent { + /** Index exactly the query terms — every query term matches in every segment. */ + QUERY_ONLY, + /** Index only 2 of the query terms per segment — most seeks are misses. */ + SPARSE, + /** Index 50k random terms independent of the query — large dictionary, zero matches. */ + RANDOM_50K + } + + enum InputShape { + /** {@code Arrays.asList(shuffledArray)} — both queries radix-sort internally. */ + UNSORTED_LIST, + /** + * {@code TreeSet} with natural-order comparator — both queries hit the skip-sort fast + * path. Isolates the storage-shape cost from the sort cost. + */ + SORTED_SET + } + + @State(Scope.Benchmark) + public static class BenchState { + + @Param({"30", "300", "3000", "30000"}) + public int numTerms; + + @Param({"5", "20", "50"}) + public int numSegments; + + @Param({"QUERY_ONLY", "SPARSE", "RANDOM_50K"}) + public String indexContent; + + @Param({"UNSORTED_LIST", "SORTED_SET"}) + public String inputShape; + + Collection termsInput; + IndexSearcher searcher; + + private DirectoryReader reader; + private ByteBuffersDirectory directory; + + @Setup(Level.Trial) + public void setup() throws IOException { + Random rng = new Random(42); + BytesRef[] sortedTerms = generateSortedTerms(rng, numTerms); + + // Build the input collection in the requested shape. + InputShape shape = InputShape.valueOf(inputShape); + switch (shape) { + case UNSORTED_LIST: + { + List shuffled = new ArrayList<>(Arrays.asList(sortedTerms)); + Collections.shuffle(shuffled, new Random(rng.nextLong())); + termsInput = shuffled; + break; + } + case SORTED_SET: + { + // TreeSet with no explicit comparator → natural-order; both query ctors will skip + // the radix sort. 
+ TreeSet set = new TreeSet<>(); + Collections.addAll(set, sortedTerms); + termsInput = set; + break; + } + } + + IndexContent content = IndexContent.valueOf(indexContent); + // Two deterministic entries from the query set — first and middle term. + BytesRef[] sparseTerms = new BytesRef[] {sortedTerms[0], sortedTerms[numTerms / 2]}; + // Extra terms pre-generated only when needed. + BytesRef[] extraTerms = + content == IndexContent.RANDOM_50K ? generateSortedTerms(rng, 50_000) : null; + + directory = new ByteBuffersDirectory(); + IndexWriter writer = + new IndexWriter( + directory, new IndexWriterConfig().setMergePolicy(NoMergePolicy.INSTANCE)); + + for (int seg = 0; seg < numSegments; seg++) { + switch (content) { + case QUERY_ONLY: + for (BytesRef term : sortedTerms) { + addDoc(writer, term.utf8ToString()); + } + break; + case SPARSE: + for (BytesRef term : sparseTerms) { + addDoc(writer, term.utf8ToString()); + } + break; + case RANDOM_50K: + for (BytesRef term : extraTerms) { + addDoc(writer, term.utf8ToString()); + } + break; + } + writer.commit(); + } + + reader = DirectoryReader.open(writer); + writer.close(); + searcher = new IndexSearcher(reader); + searcher.setQueryCache(null); + } + + @TearDown(Level.Trial) + public void tearDown() throws IOException { + reader.close(); + directory.close(); + } + + private static BytesRef[] generateSortedTerms(Random rng, int count) { + BytesRef[] terms = new BytesRef[count]; + for (int i = 0; i < count; i++) { + terms[i] = new BytesRef(String.format(Locale.ROOT, "%016x", rng.nextLong())); + } + Arrays.sort(terms); + return terms; + } + + private static void addDoc(IndexWriter writer, String value) throws IOException { + Document doc = new Document(); + doc.add(new StringField(FIELD, value, Field.Store.NO)); + writer.addDocument(doc); + } + } + + @Benchmark + public void constructTermInSetQuery(BenchState state, Blackhole bh) { + bh.consume(new TermInSetQuery(FIELD, state.termsInput)); + } + + @Benchmark + public void 
constructArrayTermInSetQuery(BenchState state, Blackhole bh) { + bh.consume(new ArrayTermInSetQuery(FIELD, state.termsInput)); + } + + @Benchmark + public void constructAndIterateTermInSetQuery(BenchState state, Blackhole bh) throws IOException { + bh.consume(state.searcher.count(new TermInSetQuery(FIELD, state.termsInput))); + } + + @Benchmark + public void constructAndIterateArrayTermInSetQuery(BenchState state, Blackhole bh) + throws IOException { + bh.consume(state.searcher.count(new ArrayTermInSetQuery(FIELD, state.termsInput))); + } + + /** + * State for {@code equals*} benchmarks. Builds equal query pairs so we measure {@code equals()} + * on cache-hit (equal) queries — the hot path the packed-{@code byte[]} fast path targets. + * + *

Strategies under comparison: {@link TermInSetQuery} (Lucene baseline), {@link + * ArrayTermInSetQuery} with VInt-prefix packing, flat-packed {@code byte[]} without boundaries + * (incorrect — distinct term boundaries with the same concatenation collide; included only as a + * perf reference), and packed {@code byte[]} + separate {@code int[]} lengths array (correct + * alternative shape with two {@code memcmp}s). + */ + @State(Scope.Benchmark) + public static class EqualsState { + + @Param({"300", "3000", "30000"}) + public int numTerms; + + Query termInSetA; + Query termInSetB; + Query arrayTermInSetA; + Query arrayTermInSetB; + + byte[] flatPackedA; + byte[] flatPackedB; + int flatHashA; + int flatHashB; + + byte[] lengthsPackedA; + byte[] lengthsPackedB; + int[] termLengthsA; + int[] termLengthsB; + int lengthsHashA; + int lengthsHashB; + + @Setup(Level.Trial) + public void setup() { + Random rng = new Random(42); + BytesRef[] sorted = generateSortedTerms(rng, numTerms); + + List termsList = Arrays.asList(sorted); + termInSetA = new TermInSetQuery(FIELD, termsList); + termInSetB = new TermInSetQuery(FIELD, termsList); + + List copyA = new ArrayList<>(termsList.size()); + List copyB = new ArrayList<>(termsList.size()); + for (BytesRef t : sorted) { + copyA.add(BytesRef.deepCopyOf(t)); + copyB.add(BytesRef.deepCopyOf(t)); + } + arrayTermInSetA = new ArrayTermInSetQuery(FIELD, copyA); + arrayTermInSetB = new ArrayTermInSetQuery(FIELD, copyB); + + flatPackedA = flatPack(sorted); + flatPackedB = flatPack(sorted); + flatHashA = Arrays.hashCode(flatPackedA); + flatHashB = Arrays.hashCode(flatPackedB); + + lengthsPackedA = flatPack(sorted); + lengthsPackedB = flatPack(sorted); + termLengthsA = extractLengths(sorted); + termLengthsB = extractLengths(sorted); + lengthsHashA = 31 * Arrays.hashCode(lengthsPackedA) + Arrays.hashCode(termLengthsA); + lengthsHashB = 31 * Arrays.hashCode(lengthsPackedB) + Arrays.hashCode(termLengthsB); + } + + private static BytesRef[] 
generateSortedTerms(Random rng, int count) { + BytesRef[] terms = new BytesRef[count]; + for (int i = 0; i < count; i++) { + terms[i] = new BytesRef(String.format(Locale.ROOT, "%016x", rng.nextLong())); + } + Arrays.sort(terms); + return terms; + } + + private static byte[] flatPack(BytesRef[] terms) { + int total = 0; + for (BytesRef t : terms) { + total += t.length; + } + byte[] packed = new byte[total]; + int pos = 0; + for (BytesRef t : terms) { + System.arraycopy(t.bytes, t.offset, packed, pos, t.length); + pos += t.length; + } + return packed; + } + + private static int[] extractLengths(BytesRef[] terms) { + int[] lengths = new int[terms.length]; + for (int i = 0; i < terms.length; i++) { + lengths[i] = terms[i].length; + } + return lengths; + } + } + + @Benchmark + public void equalsTermInSetQuery(EqualsState state, Blackhole bh) { + bh.consume(state.termInSetA.equals(state.termInSetB)); + } + + /** VInt-prefix packed — what {@link ArrayTermInSetQuery} ships. Single {@code Arrays.equals}. */ + @Benchmark + public void equalsArrayTermInSetQuery(EqualsState state, Blackhole bh) { + bh.consume(state.arrayTermInSetA.equals(state.arrayTermInSetB)); + } + + /** + * Flat packed without term boundaries — incorrect (distinct term boundaries with the same + * concatenation collide). Included only as a perf reference for the equals fast path. + */ + @Benchmark + public void equalsFlatPacked(EqualsState state, Blackhole bh) { + bh.consume( + state.flatHashA == state.flatHashB && Arrays.equals(state.flatPackedA, state.flatPackedB)); + } + + /** Packed {@code byte[]} + separate {@code int[]} lengths — correct alternative shape. 
*/ + @Benchmark + public void equalsPackedPlusLengths(EqualsState state, Blackhole bh) { + bh.consume( + state.lengthsHashA == state.lengthsHashB + && Arrays.equals(state.lengthsPackedA, state.lengthsPackedB) + && Arrays.equals(state.termLengthsA, state.termLengthsB)); + } +} diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/search/ArrayTermInSetQuery.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/search/ArrayTermInSetQuery.java new file mode 100644 index 000000000000..3afe4dbd508e --- /dev/null +++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/search/ArrayTermInSetQuery.java @@ -0,0 +1,295 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
package org.apache.lucene.sandbox.search;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.SortedSet;
import org.apache.lucene.index.FilteredTermsEnum;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.search.MultiTermQuery;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.TermInSetQuery;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.AttributeSource;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.BytesRefBuilder;
import org.apache.lucene.util.BytesRefComparator;
import org.apache.lucene.util.BytesRefIterator;
import org.apache.lucene.util.RamUsageEstimator;
import org.apache.lucene.util.StringSorter;
import org.apache.lucene.util.automaton.Automata;
import org.apache.lucene.util.automaton.Automaton;
import org.apache.lucene.util.automaton.ByteRunAutomaton;

/**
 * Array-backed alternative to {@link TermInSetQuery} that stores its terms as a sorted {@code
 * BytesRef[]} (zero-copy views into a single packed {@code byte[]}) instead of {@link
 * org.apache.lucene.index.PrefixCodedTerms}. Trades a bit of RAM for cheaper per-segment iteration
 * and a vectorized {@link #equals}/{@link #hashCode} fast path on the packed bytes.
 *
 * <p>Constructed and used identically to {@link TermInSetQuery} — accepts an arbitrary {@link
 * Collection} of terms and sorts/deduplicates internally. If the input is a {@link SortedSet} with
 * natural-order comparator, the sort is skipped (same fast path as {@link TermInSetQuery}).
 *
 * <p>Useful in CPU-bound workloads with large term sets (tens of thousands of terms) where many
 * segments are too small to benefit from the {@link org.apache.lucene.search.LRUQueryCache}, so the
 * per-segment unpacking cost of {@link org.apache.lucene.index.PrefixCodedTerms} dominates.
 */
public class ArrayTermInSetQuery extends MultiTermQuery implements Accountable {

  private static final long BASE_RAM_BYTES_USED =
      RamUsageEstimator.shallowSizeOfInstance(ArrayTermInSetQuery.class);

  private static final long BYTES_REF_SHALLOW_SIZE =
      RamUsageEstimator.shallowSizeOfInstance(BytesRef.class);

  /** All unique terms, VInt-length-prefixed and concatenated in sorted order. */
  private final byte[] packedTerms;
  /** Sorted unique terms — zero-copy views into {@link #packedTerms}. */
  private final BytesRef[] terms;
  /** {@code Arrays.hashCode(packedTerms)}, precomputed once at construction. */
  private final int cachedHashCode;
  private final long ramBytesUsed;

  /**
   * Constructs a query matching documents containing any of the given terms in the given field.
   *
   * <p>Terms are MSB radix-sorted and adjacent duplicates removed, then packed into a single
   * contiguous {@code byte[]} with VInt-length prefixes encoding term boundaries. The
   * length-prefixed packing makes the representation canonically unique so {@link #equals} and
   * {@link #hashCode} can operate on the raw {@code byte[]} via a single vectorized memcmp.
   *
   * @param field the field all terms belong to
   * @param terms the terms to match; copied, sorted and deduplicated internally
   */
  public ArrayTermInSetQuery(String field, Collection<BytesRef> terms) {
    super(field, CONSTANT_SCORE_BLENDED_REWRITE);

    BytesRef[] sortedTerms = terms.toArray(new BytesRef[0]);
    // Already sorted if we are a SortedSet with natural-order comparator. Same O(1)
    // fast path TermInSetQuery uses.
    boolean sorted =
        terms instanceof SortedSet && ((SortedSet<BytesRef>) terms).comparator() == null;
    if (sorted == false) {
      new StringSorter(BytesRefComparator.NATURAL) {
        @Override
        protected void get(BytesRefBuilder builder, BytesRef result, int i) {
          BytesRef t = sortedTerms[i];
          result.bytes = t.bytes;
          result.offset = t.offset;
          result.length = t.length;
        }

        @Override
        protected void swap(int i, int j) {
          BytesRef tmp = sortedTerms[i];
          sortedTerms[i] = sortedTerms[j];
          sortedTerms[j] = tmp;
        }
      }.sort(0, sortedTerms.length);
    }

    // Walk once to count uniques and total packed length, then a second time to
    // write packed bytes + views. Two passes give us exact-sized allocations
    // without an intermediate growable buffer or an explicit dedup'd array.
    int uniqueCount = 0;
    int totalLen = 0;
    BytesRef previous = null;
    for (BytesRef t : sortedTerms) {
      if (previous != null && previous.equals(t)) {
        continue;
      }
      uniqueCount++;
      totalLen += vIntSize(t.length) + t.length;
      previous = t;
    }

    byte[] packed = new byte[totalLen];
    BytesRef[] views = new BytesRef[uniqueCount];
    int pos = 0;
    int idx = 0;
    previous = null;
    for (BytesRef t : sortedTerms) {
      if (previous != null && previous.equals(t)) {
        continue;
      }
      pos = writeVInt(packed, pos, t.length);
      System.arraycopy(t.bytes, t.offset, packed, pos, t.length);
      // View points at the term bytes only, past the VInt prefix just written.
      views[idx++] = new BytesRef(packed, pos, t.length);
      pos += t.length;
      previous = t;
    }

    this.packedTerms = packed;
    this.terms = views;
    this.cachedHashCode = Arrays.hashCode(packed);
    this.ramBytesUsed =
        BASE_RAM_BYTES_USED
            + RamUsageEstimator.sizeOf(packed)
            + RamUsageEstimator.shallowSizeOf(views)
            + (long) views.length * BYTES_REF_SHALLOW_SIZE;
  }

  /** Number of bytes a VInt encoding of {@code value} occupies (1–5). */
  private static int vIntSize(int value) {
    int size = 1;
    while ((value & ~0x7F) != 0) {
      size++;
      value >>>= 7;
    }
    return size;
  }

  /** Writes {@code value} as a VInt at {@code pos}, returning the new position. */
  private static int writeVInt(byte[] dest, int pos, int value) {
    while ((value & ~0x7F) != 0) {
      dest[pos++] = (byte) ((value & 0x7F) | 0x80);
      value >>>= 7;
    }
    dest[pos++] = (byte) value;
    return pos;
  }

  @Override
  protected TermsEnum getTermsEnum(Terms terms, AttributeSource atts) throws IOException {
    return new ArraySetEnum(terms.iterator(), this.terms);
  }

  @Override
  public long getTermsCount() {
    return terms.length;
  }

  @Override
  public void visit(QueryVisitor visitor) {
    if (visitor.acceptField(getField()) == false) {
      return;
    }
    if (terms.length == 1) {
      visitor.consumeTerms(this, new Term(getField(), terms[0]));
    }
    if (terms.length > 1) {
      visitor.consumeTermsMatching(this, getField(), this::asByteRunAutomaton);
    }
  }

  private ByteRunAutomaton asByteRunAutomaton() {
    try {
      Automaton a = Automata.makeBinaryStringUnion(new ArrayBytesRefIterator(terms));
      return new ByteRunAutomaton(a, true);
    } catch (IOException e) {
      // Shouldn't happen: ArrayBytesRefIterator's next() never throws.
      throw new UncheckedIOException(e);
    }
  }

  @Override
  public boolean equals(Object other) {
    return sameClassAs(other) && equalsTo(getClass().cast(other));
  }

  private boolean equalsTo(ArrayTermInSetQuery other) {
    // Cheap hash check first; the packed comparison is a single vectorized memcmp.
    return cachedHashCode == other.cachedHashCode
        && getField().equals(other.getField())
        && Arrays.equals(packedTerms, other.packedTerms);
  }

  @Override
  public int hashCode() {
    return 31 * classHash() + cachedHashCode;
  }

  @Override
  public String toString(String defaultField) {
    StringBuilder sb = new StringBuilder();
    sb.append(getField()).append(":(");
    for (int i = 0; i < terms.length; i++) {
      if (i > 0) {
        sb.append(' ');
      }
      sb.append(Term.toString(terms[i]));
    }
    sb.append(')');
    return sb.toString();
  }

  @Override
  public long ramBytesUsed() {
    return ramBytesUsed;
  }

  @Override
  public Collection<Accountable> getChildResources() {
    return Collections.emptyList();
  }

  /**
   * Same ping-pong logic as {@code TermInSetQuery.SetEnum} but backed by a pre-decoded {@link
   * BytesRef}{@code []} instead of a streaming {@code PrefixCodedTerms.TermIterator}.
   */
  private static class ArraySetEnum extends FilteredTermsEnum {
    private final BytesRef[] terms;
    private int idx = 0;

    ArraySetEnum(TermsEnum termsEnum, BytesRef[] terms) {
      super(termsEnum);
      this.terms = terms;
    }

    private BytesRef currentTerm() {
      return idx < terms.length ? terms[idx] : null;
    }

    @Override
    protected AcceptStatus accept(BytesRef term) throws IOException {
      // Advance past query terms below the segment's current term.
      int cmp = 0;
      while (idx < terms.length && (cmp = terms[idx].compareTo(term)) < 0) {
        idx++;
      }
      if (idx >= terms.length) {
        return AcceptStatus.END;
      } else if (cmp == 0) {
        return AcceptStatus.YES_AND_SEEK;
      } else {
        return AcceptStatus.NO_AND_SEEK;
      }
    }

    @Override
    protected BytesRef nextSeekTerm(BytesRef currentTerm) throws IOException {
      if (currentTerm == null) {
        return currentTerm();
      }
      // Skip query terms at or below the last seeked term, then seek to the next one.
      while (idx < terms.length && terms[idx].compareTo(currentTerm) <= 0) {
        idx++;
      }
      return currentTerm();
    }
  }

  /** Minimal {@link BytesRefIterator} over a pre-sorted array; {@code next()} never throws. */
  private static class ArrayBytesRefIterator implements BytesRefIterator {
    private final BytesRef[] terms;
    private int idx = 0;

    ArrayBytesRefIterator(BytesRef[] terms) {
      this.terms = terms;
    }

    @Override
    public BytesRef next() {
      return idx < terms.length ? terms[idx++] : null;
    }
  }
}
package org.apache.lucene.sandbox.search;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Supplier;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.StringField;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.TermInSetQuery;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.automaton.ByteRunnable;

/** Tests for {@link ArrayTermInSetQuery}, cross-checked against {@link TermInSetQuery}. */
public class TestArrayTermInSetQuery extends LuceneTestCase {

  private static final String FIELD = "group";

  public void testMatchesSameDocsAsTermInSetQuery() throws IOException {
    Directory dir = newDirectory();
    IndexWriter writer = new IndexWriter(dir, newIndexWriterConfig());

    addDoc(writer, "alpha");
    addDoc(writer, "beta");
    addDoc(writer, "gamma");
    addDoc(writer, "delta");
    writer.commit();
    addDoc(writer, "epsilon");
    addDoc(writer, "zeta");
    writer.commit();

    DirectoryReader reader = DirectoryReader.open(writer);
    IndexSearcher searcher = newSearcher(reader);
    writer.close();

    List<BytesRef> queryTerms =
        Arrays.asList(new BytesRef("alpha"), new BytesRef("gamma"), new BytesRef("zeta"));

    TopDocs vanillaResults = searcher.search(new TermInSetQuery(FIELD, queryTerms), 100);
    TopDocs arrayResults = searcher.search(new ArrayTermInSetQuery(FIELD, queryTerms), 100);

    assertEquals(vanillaResults.scoreDocs.length, arrayResults.scoreDocs.length);
    for (int i = 0; i < vanillaResults.scoreDocs.length; i++) {
      assertEquals(vanillaResults.scoreDocs[i].doc, arrayResults.scoreDocs[i].doc);
      assertEquals(vanillaResults.scoreDocs[i].score, arrayResults.scoreDocs[i].score, 0.0f);
    }

    reader.close();
    dir.close();
  }

  public void testReturnsNoResultsForNonMatchingTerms() throws IOException {
    Directory dir = newDirectory();
    IndexWriter writer = new IndexWriter(dir, newIndexWriterConfig());
    addDoc(writer, "alpha");
    addDoc(writer, "beta");
    writer.commit();

    DirectoryReader reader = DirectoryReader.open(writer);
    IndexSearcher searcher = newSearcher(reader);
    writer.close();

    TopDocs results =
        searcher.search(
            new ArrayTermInSetQuery(
                FIELD, Arrays.asList(new BytesRef("nonexistent"), new BytesRef("bogus"))),
            100);
    assertEquals(0, results.scoreDocs.length);

    reader.close();
    dir.close();
  }

  public void testWorksWithSingleTerm() throws IOException {
    Directory dir = newDirectory();
    IndexWriter writer = new IndexWriter(dir, newIndexWriterConfig());
    addDoc(writer, "alpha");
    addDoc(writer, "beta");
    addDoc(writer, "alpha");
    writer.commit();

    DirectoryReader reader = DirectoryReader.open(writer);
    IndexSearcher searcher = newSearcher(reader);
    writer.close();

    TopDocs results =
        searcher.search(new ArrayTermInSetQuery(FIELD, Arrays.asList(new BytesRef("alpha"))), 100);
    assertEquals(2, results.scoreDocs.length);

    reader.close();
    dir.close();
  }

  public void testWorksAcrossMultipleSegments() throws IOException {
    Directory dir = newDirectory();
    IndexWriter writer = new IndexWriter(dir, newIndexWriterConfig());

    for (int seg = 0; seg < 5; seg++) {
      for (int doc = 0; doc < 20; doc++) {
        addDoc(writer, "term_" + (seg * 20 + doc));
      }
      writer.commit();
    }

    DirectoryReader reader = DirectoryReader.open(writer);
    IndexSearcher searcher = newSearcher(reader);
    writer.close();

    List<BytesRef> queryTerms =
        Arrays.asList(
            new BytesRef("term_0"),
            new BytesRef("term_25"),
            new BytesRef("term_50"),
            new BytesRef("term_75"),
            new BytesRef("term_99"));

    TopDocs vanillaResults = searcher.search(new TermInSetQuery(FIELD, queryTerms), 100);
    TopDocs arrayResults = searcher.search(new ArrayTermInSetQuery(FIELD, queryTerms), 100);

    assertEquals(vanillaResults.scoreDocs.length, arrayResults.scoreDocs.length);

    reader.close();
    dir.close();
  }

  public void testEqualQueriesHaveSameHashCode() {
    ArrayTermInSetQuery a =
        new ArrayTermInSetQuery(FIELD, Arrays.asList(new BytesRef("alpha"), new BytesRef("beta")));
    ArrayTermInSetQuery b =
        new ArrayTermInSetQuery(FIELD, Arrays.asList(new BytesRef("alpha"), new BytesRef("beta")));

    assertEquals(a, b);
    assertEquals(a.hashCode(), b.hashCode());
  }

  public void testSortIsInternal() {
    // Same set of terms in different input orders must produce equal queries.
    ArrayTermInSetQuery sorted =
        new ArrayTermInSetQuery(
            FIELD,
            Arrays.asList(new BytesRef("alpha"), new BytesRef("beta"), new BytesRef("gamma")));
    ArrayTermInSetQuery reversed =
        new ArrayTermInSetQuery(
            FIELD,
            Arrays.asList(new BytesRef("gamma"), new BytesRef("beta"), new BytesRef("alpha")));

    assertEquals(sorted, reversed);
    assertEquals(sorted.hashCode(), reversed.hashCode());
  }

  public void testDuplicatesInInputAreDeduplicated() {
    ArrayTermInSetQuery deduped =
        new ArrayTermInSetQuery(FIELD, Arrays.asList(new BytesRef("alpha"), new BytesRef("beta")));
    ArrayTermInSetQuery withDupes =
        new ArrayTermInSetQuery(
            FIELD,
            Arrays.asList(
                new BytesRef("alpha"),
                new BytesRef("alpha"),
                new BytesRef("beta"),
                new BytesRef("beta"),
                new BytesRef("beta")));

    assertEquals(deduped, withDupes);
    assertEquals(2, deduped.getTermsCount());
    assertEquals(2, withDupes.getTermsCount());
  }

  public void testDifferentTermsAreNotEqual() {
    ArrayTermInSetQuery a =
        new ArrayTermInSetQuery(FIELD, Arrays.asList(new BytesRef("alpha"), new BytesRef("beta")));
    ArrayTermInSetQuery b =
        new ArrayTermInSetQuery(FIELD, Arrays.asList(new BytesRef("alpha"), new BytesRef("gamma")));

    assertNotEquals(a, b);
  }

  public void testDifferentFieldsAreNotEqual() {
    ArrayTermInSetQuery a =
        new ArrayTermInSetQuery("field_a", Arrays.asList(new BytesRef("alpha")));
    ArrayTermInSetQuery b =
        new ArrayTermInSetQuery("field_b", Arrays.asList(new BytesRef("alpha")));

    assertNotEquals(a, b);
  }

  public void testToStringContainsFieldAndTerms() {
    ArrayTermInSetQuery query =
        new ArrayTermInSetQuery(FIELD, Arrays.asList(new BytesRef("alpha"), new BytesRef("beta")));
    String str = query.toString();
    assertTrue("toString should contain field name, got: " + str, str.contains(FIELD));
    assertTrue("toString should contain term 'alpha', got: " + str, str.contains("alpha"));
    assertTrue("toString should contain term 'beta', got: " + str, str.contains("beta"));
  }

  public void testVisitCallsConsumeTermsForSingleTerm() {
    ArrayTermInSetQuery query = new ArrayTermInSetQuery(FIELD, Arrays.asList(new BytesRef("only")));
    AtomicBoolean consumed = new AtomicBoolean(false);

    query.visit(
        new QueryVisitor() {
          @Override
          public void consumeTerms(Query query, Term... terms) {
            consumed.set(true);
            assertEquals(1, terms.length);
            assertEquals(FIELD, terms[0].field());
            assertEquals("only", terms[0].text());
          }
        });

    assertTrue("consumeTerms should have been called for single-term query", consumed.get());
  }

  public void testVisitCallsConsumeTermsMatchingForMultipleTerms() {
    ArrayTermInSetQuery query =
        new ArrayTermInSetQuery(FIELD, Arrays.asList(new BytesRef("alpha"), new BytesRef("beta")));
    AtomicBoolean consumed = new AtomicBoolean(false);

    query.visit(
        new QueryVisitor() {
          @Override
          public void consumeTermsMatching(
              Query query, String field, Supplier<ByteRunnable> automaton) {
            consumed.set(true);
            assertEquals(FIELD, field);
            ByteRunnable bra = automaton.get();
            assertTrue(
                "automaton should accept 'alpha'", bra.run(new BytesRef("alpha").bytes, 0, 5));
            assertTrue("automaton should accept 'beta'", bra.run(new BytesRef("beta").bytes, 0, 4));
          }
        });

    assertTrue("consumeTermsMatching should have been called for multi-term query", consumed.get());
  }

  public void testRamBytesUsedIsPositiveAndScalesWithTermCount() {
    ArrayTermInSetQuery small = new ArrayTermInSetQuery(FIELD, Arrays.asList(new BytesRef("a")));
    ArrayTermInSetQuery large = new ArrayTermInSetQuery(FIELD, Arrays.asList(buildTerms(100)));

    assertTrue("ramBytesUsed should be positive", small.ramBytesUsed() > 0);
    assertTrue("more terms should use more RAM", large.ramBytesUsed() > small.ramBytesUsed());
  }

  public void testGetTermsCountReturnsNumberOfTerms() {
    ArrayTermInSetQuery query =
        new ArrayTermInSetQuery(
            FIELD, Arrays.asList(new BytesRef("a"), new BytesRef("b"), new BytesRef("c")));
    assertEquals(3, query.getTermsCount());
  }

  public void testDifferentTermBoundariesWithSameConcatenationAreNotEqual() {
    // ["a","bc"] and ["ab","c"] have identical concatenated bytes "abc" but different term
    // boundaries. VInt-length prefixes in the packed representation must distinguish them.
    ArrayTermInSetQuery q1 =
        new ArrayTermInSetQuery(FIELD, Arrays.asList(new BytesRef("a"), new BytesRef("bc")));
    ArrayTermInSetQuery q2 =
        new ArrayTermInSetQuery(FIELD, Arrays.asList(new BytesRef("ab"), new BytesRef("c")));

    assertNotEquals(q1, q2);
  }

  public void testDifferentTermCountWithSameConcatenationAreNotEqual() {
    // ["abc"] vs ["a","bc"] — same bytes, different term count
    ArrayTermInSetQuery q1 = new ArrayTermInSetQuery(FIELD, Arrays.asList(new BytesRef("abc")));
    ArrayTermInSetQuery q2 =
        new ArrayTermInSetQuery(FIELD, Arrays.asList(new BytesRef("a"), new BytesRef("bc")));

    assertNotEquals(q1, q2);
  }

  public void testPackedBackingWorksWithNonZeroOffsetBytesRefs() throws IOException {
    Directory dir = newDirectory();
    IndexWriter writer = new IndexWriter(dir, newIndexWriterConfig());
    addDoc(writer, "alpha");
    addDoc(writer, "beta");
    writer.commit();

    DirectoryReader reader = DirectoryReader.open(writer);
    IndexSearcher searcher = newSearcher(reader);
    writer.close();

    // BytesRef with non-zero offset into a larger backing array
    byte[] buf = "XXXalphaYYY".getBytes(StandardCharsets.UTF_8);
    BytesRef offsetRef = new BytesRef(buf, 3, 5);
    List<BytesRef> terms = Arrays.asList(offsetRef, new BytesRef("beta"));

    TopDocs results = searcher.search(new ArrayTermInSetQuery(FIELD, terms), 100);
    assertEquals(2, results.scoreDocs.length);

    // Equality still works across independently constructed queries that happen to have one
    // input with a non-zero-offset BytesRef.
    List<BytesRef> terms2 = Arrays.asList(new BytesRef("alpha"), new BytesRef("beta"));
    ArrayTermInSetQuery q1 = new ArrayTermInSetQuery(FIELD, terms);
    ArrayTermInSetQuery q2 = new ArrayTermInSetQuery(FIELD, terms2);
    assertEquals(q1, q2);
    assertEquals(q1.hashCode(), q2.hashCode());

    reader.close();
    dir.close();
  }

  /** Builds {@code count} distinct, zero-padded terms: term_0000, term_0001, ... */
  private static BytesRef[] buildTerms(int count) {
    BytesRef[] terms = new BytesRef[count];
    for (int i = 0; i < count; i++) {
      terms[i] = new BytesRef("term_" + String.format(Locale.ROOT, "%04d", i));
    }
    return terms;
  }

  /** Adds a single-field document holding {@code value} as a stored StringField. */
  private static void addDoc(IndexWriter writer, String value) throws IOException {
    Document doc = new Document();
    doc.add(new StringField(FIELD, value, Field.Store.YES));
    writer.addDocument(doc);
  }
}