diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index 87ebb03bee77..6c77dd0774a1 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -335,6 +335,8 @@ Improvements * GITHUB#15574, GITHUB#15995: Introduce TopGroupsCollectorManager to parallelize search when using TopGroupsCollector. (Binlong Gao) +* GITHUB#15565: Introduce AllGroupHeadsCollectorManager to parallelize search when using AllGroupHeadsCollector. (Binlong Gao) + * GITHUB#15989: DocValuesRangeIterator always tries to use skipper-based block iteration as its approximation. (Alan Woodward) diff --git a/lucene/grouping/src/java/org/apache/lucene/search/grouping/AllGroupHeadsCollector.java b/lucene/grouping/src/java/org/apache/lucene/search/grouping/AllGroupHeadsCollector.java index 4ff5e16c96d4..013f83b3f3c6 100644 --- a/lucene/grouping/src/java/org/apache/lucene/search/grouping/AllGroupHeadsCollector.java +++ b/lucene/grouping/src/java/org/apache/lucene/search/grouping/AllGroupHeadsCollector.java @@ -232,6 +232,20 @@ protected void setNextReader(LeafReaderContext ctx) throws IOException { * @throws IOException If I/O related errors occur */ protected abstract void updateDocHead(int doc) throws IOException; + + /** + * Returns the sort values for this group head. + * + * @return the sort values, or null if not stored + */ + protected abstract Object[] getSortValues(); + + /** + * Returns the field comparators used to determine the group head ordering. 
+ * + * @return the comparators, one per sort field + */ + protected abstract FieldComparator[] getComparators(); } /** General implementation using a {@link FieldComparator} to select the group head */ @@ -252,6 +266,7 @@ private static class SortingGroupHead extends GroupHead { final FieldComparator[] comparators; final LeafFieldComparator[] leafComparators; + final Object[] sortValues; protected SortingGroupHead( Sort sort, T groupValue, int doc, LeafReaderContext context, Scorable scorer) @@ -260,12 +275,14 @@ protected SortingGroupHead( final SortField[] sortFields = sort.getSort(); comparators = new FieldComparator[sortFields.length]; leafComparators = new LeafFieldComparator[sortFields.length]; + sortValues = new Object[sortFields.length]; for (int i = 0; i < sortFields.length; i++) { comparators[i] = sortFields[i].getComparator(1, Pruning.NONE); leafComparators[i] = comparators[i].getLeafComparator(context); leafComparators[i].setScorer(scorer); leafComparators[i].copy(0, doc); leafComparators[i].setBottom(0); + sortValues[i] = comparators[i].value(0); } } @@ -291,12 +308,23 @@ public int compare(int compIDX, int doc) throws IOException { @Override public void updateDocHead(int doc) throws IOException { - for (LeafFieldComparator comparator : leafComparators) { - comparator.copy(0, doc); - comparator.setBottom(0); + for (int i = 0; i < leafComparators.length; i++) { + leafComparators[i].copy(0, doc); + leafComparators[i].setBottom(0); + sortValues[i] = comparators[i].value(0); } this.doc = doc + docBase; } + + @Override + protected Object[] getSortValues() { + return sortValues; + } + + @Override + protected FieldComparator[] getComparators() { + return comparators; + } } /** Specialized implementation for sorting by score */ @@ -317,12 +345,14 @@ private static class ScoringGroupHead extends GroupHead { private Scorable scorer; private float topScore; + private final Object[] sortValues; protected ScoringGroupHead(Scorable scorer, T groupValue, int doc, 
int docBase) throws IOException { super(groupValue, doc, docBase); this.scorer = scorer; this.topScore = scorer.score(); + this.sortValues = new Object[] {topScore}; } @Override @@ -344,6 +374,17 @@ protected int compare(int compIDX, int doc) throws IOException { @Override protected void updateDocHead(int doc) throws IOException { this.doc = doc + docBase; + sortValues[0] = topScore; + } + + @Override + protected Object[] getSortValues() { + return sortValues; + } + + @Override + protected FieldComparator[] getComparators() { + return null; } } } diff --git a/lucene/grouping/src/java/org/apache/lucene/search/grouping/AllGroupHeadsCollectorManager.java b/lucene/grouping/src/java/org/apache/lucene/search/grouping/AllGroupHeadsCollectorManager.java new file mode 100644 index 000000000000..ee3cfd8c5a99 --- /dev/null +++ b/lucene/grouping/src/java/org/apache/lucene/search/grouping/AllGroupHeadsCollectorManager.java @@ -0,0 +1,164 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.lucene.search.grouping; + +import java.io.IOException; +import java.util.Collection; +import java.util.HashMap; +import java.util.Map; +import java.util.function.Supplier; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.search.CollectorManager; +import org.apache.lucene.search.FieldComparator; +import org.apache.lucene.search.Sort; +import org.apache.lucene.search.SortField; +import org.apache.lucene.util.Bits; +import org.apache.lucene.util.FixedBitSet; + +/** + * A {@link CollectorManager} implementation for {@link AllGroupHeadsCollector} that collects the + * most relevant document (group head) for each group across multiple segments and merges the + * per-segment results into a single {@link GroupHeadsResult}. + * + *

<p>Example usage: + * + * <pre class="prettyprint">
+ * IndexSearcher searcher = ...; // your IndexSearcher
+ * AllGroupHeadsCollectorManager&lt;BytesRef&gt; manager =
+ *     new AllGroupHeadsCollectorManager&lt;&gt;(
+ *         () -&gt; new TermGroupSelector("category"), Sort.RELEVANCE);
+ * GroupHeadsResult result = searcher.search(new MatchAllDocsQuery(), manager);
+ * Bits groupHeadsBits = result.retrieveGroupHeads(searcher.getIndexReader().maxDoc());
+ * </pre>
+ * + * @param the type of the group value + * @lucene.experimental + */ +public class AllGroupHeadsCollectorManager + implements CollectorManager< + AllGroupHeadsCollector, AllGroupHeadsCollectorManager.GroupHeadsResult> { + + /** Holds the merged group heads and provides access as an {@code int[]} or {@link Bits}. */ + public static class GroupHeadsResult { + private final int[] groupHeads; + + private GroupHeadsResult(int[] groupHeads) { + this.groupHeads = groupHeads; + } + + /** Returns the group head document IDs as an array. */ + public int[] retrieveGroupHeads() { + return groupHeads; + } + + /** + * Returns the group head document IDs as a {@link Bits} set of size {@code maxDoc}, suitable + * for use as a filter. + * + * @param maxDoc The maxDoc of the top level {@link IndexReader}. + */ + public Bits retrieveGroupHeads(int maxDoc) { + FixedBitSet result = new FixedBitSet(maxDoc); + for (int docId : groupHeads) { + result.set(docId); + } + return result; + } + } + + private static final class GroupHeadWithValues { + int doc; + final Object[] sortValues; + + GroupHeadWithValues(int doc, Object[] sortValues) { + this.doc = doc; + this.sortValues = sortValues; + } + } + + private final Supplier> groupSelectorFactory; + private final Sort sortWithinGroup; + + /** + * Creates a new AllGroupHeadsCollectorManager. 
+ * + * @param groupSelectorFactory factory to create group selectors for each collector + * @param sortWithinGroup the sort to use within each group to determine the group head + */ + public AllGroupHeadsCollectorManager( + Supplier> groupSelectorFactory, Sort sortWithinGroup) { + this.groupSelectorFactory = groupSelectorFactory; + this.sortWithinGroup = sortWithinGroup; + } + + @Override + public AllGroupHeadsCollector newCollector() throws IOException { + return AllGroupHeadsCollector.newCollector(groupSelectorFactory.get(), sortWithinGroup); + } + + @Override + public GroupHeadsResult reduce(Collection> collectors) { + Map mergedHeads = new HashMap<>(); + SortField[] sortFields = sortWithinGroup.getSort(); + + for (AllGroupHeadsCollector collector : collectors) { + mergeCollectorHeads(collector, mergedHeads, sortFields); + } + + return new GroupHeadsResult(mergedHeads.values().stream().mapToInt(h -> h.doc).toArray()); + } + + private void mergeCollectorHeads( + AllGroupHeadsCollector collector, + Map mergedHeads, + SortField[] sortFields) { + for (AllGroupHeadsCollector.GroupHead head : collector.getCollectedGroupHeads()) { + Object[] sortValues = head.getSortValues(); + GroupHeadWithValues existing = mergedHeads.get(head.groupValue); + if (existing == null || isCompetitive(head, sortValues, existing, sortFields)) { + mergedHeads.put(head.groupValue, new GroupHeadWithValues(head.doc, sortValues)); + } + } + } + + @SuppressWarnings({"rawtypes"}) + private boolean isCompetitive( + AllGroupHeadsCollector.GroupHead head, + Object[] sortValues, + GroupHeadWithValues existing, + SortField[] sortFields) { + FieldComparator[] comparators = head.getComparators(); + int cmp; + // null comparators means sorting by relevance + if (comparators == null) { + cmp = Float.compare((float) sortValues[0], (float) existing.sortValues[0]); + return cmp > 0 || (cmp == 0 && head.doc < existing.doc); + } else { + cmp = 0; + for (int i = 0; i < sortFields.length; i++) { + 
@SuppressWarnings({"unchecked"}) + int c = comparators[i].compareValues(sortValues[i], existing.sortValues[i]); + c = sortFields[i].getReverse() ? -c : c; + if (c != 0) { + cmp = c; + break; + } + } + return cmp < 0 || (cmp == 0 && head.doc < existing.doc); + } + } +} diff --git a/lucene/grouping/src/test/org/apache/lucene/search/grouping/TestAllGroupHeadsCollector.java b/lucene/grouping/src/test/org/apache/lucene/search/grouping/TestAllGroupHeadsCollector.java index a7aa05c14183..10925d378530 100644 --- a/lucene/grouping/src/test/org/apache/lucene/search/grouping/TestAllGroupHeadsCollector.java +++ b/lucene/grouping/src/test/org/apache/lucene/search/grouping/TestAllGroupHeadsCollector.java @@ -136,45 +136,54 @@ public void testBasic() throws Exception { int maxDoc = reader.maxDoc(); Sort sortWithinGroup = new Sort(new SortField("id_1", SortField.Type.INT, true)); - AllGroupHeadsCollector allGroupHeadsCollector = - createRandomCollector(groupField, sortWithinGroup); - indexSearcher.search(new TermQuery(new Term("content", "random")), allGroupHeadsCollector); - assertTrue(arrayContains(new int[] {2, 3, 5, 7}, allGroupHeadsCollector.retrieveGroupHeads())); + AllGroupHeadsCollectorManager allGroupHeadsCollectorManager = + createRandomCollectorManager(groupField, sortWithinGroup); + AllGroupHeadsCollectorManager.GroupHeadsResult groupHeadsResult = + indexSearcher.search( + new TermQuery(new Term("content", "random")), allGroupHeadsCollectorManager); + assertTrue(arrayContains(new int[] {2, 3, 5, 7}, groupHeadsResult.retrieveGroupHeads())); assertTrue( openBitSetContains( - new int[] {2, 3, 5, 7}, allGroupHeadsCollector.retrieveGroupHeads(maxDoc), maxDoc)); + new int[] {2, 3, 5, 7}, groupHeadsResult.retrieveGroupHeads(maxDoc), maxDoc)); - allGroupHeadsCollector = createRandomCollector(groupField, sortWithinGroup); - indexSearcher.search(new TermQuery(new Term("content", "some")), allGroupHeadsCollector); - assertTrue(arrayContains(new int[] {2, 3, 4}, 
allGroupHeadsCollector.retrieveGroupHeads())); + allGroupHeadsCollectorManager = createRandomCollectorManager(groupField, sortWithinGroup); + groupHeadsResult = + indexSearcher.search( + new TermQuery(new Term("content", "some")), allGroupHeadsCollectorManager); + assertTrue(arrayContains(new int[] {2, 3, 4}, groupHeadsResult.retrieveGroupHeads())); assertTrue( openBitSetContains( - new int[] {2, 3, 4}, allGroupHeadsCollector.retrieveGroupHeads(maxDoc), maxDoc)); + new int[] {2, 3, 4}, groupHeadsResult.retrieveGroupHeads(maxDoc), maxDoc)); - allGroupHeadsCollector = createRandomCollector(groupField, sortWithinGroup); - indexSearcher.search(new TermQuery(new Term("content", "blob")), allGroupHeadsCollector); - assertTrue(arrayContains(new int[] {1, 5}, allGroupHeadsCollector.retrieveGroupHeads())); + allGroupHeadsCollectorManager = createRandomCollectorManager(groupField, sortWithinGroup); + groupHeadsResult = + indexSearcher.search( + new TermQuery(new Term("content", "blob")), allGroupHeadsCollectorManager); + assertTrue(arrayContains(new int[] {1, 5}, groupHeadsResult.retrieveGroupHeads())); assertTrue( - openBitSetContains( - new int[] {1, 5}, allGroupHeadsCollector.retrieveGroupHeads(maxDoc), maxDoc)); + openBitSetContains(new int[] {1, 5}, groupHeadsResult.retrieveGroupHeads(maxDoc), maxDoc)); // STRING sort type triggers different implementation Sort sortWithinGroup2 = new Sort(new SortField("id_2", SortField.Type.STRING, true)); - allGroupHeadsCollector = createRandomCollector(groupField, sortWithinGroup2); - indexSearcher.search(new TermQuery(new Term("content", "random")), allGroupHeadsCollector); - assertTrue(arrayContains(new int[] {2, 3, 5, 7}, allGroupHeadsCollector.retrieveGroupHeads())); + allGroupHeadsCollectorManager = createRandomCollectorManager(groupField, sortWithinGroup2); + groupHeadsResult = + indexSearcher.search( + new TermQuery(new Term("content", "random")), allGroupHeadsCollectorManager); + assertTrue(arrayContains(new int[] {2, 3, 5, 
7}, groupHeadsResult.retrieveGroupHeads())); assertTrue( openBitSetContains( - new int[] {2, 3, 5, 7}, allGroupHeadsCollector.retrieveGroupHeads(maxDoc), maxDoc)); + new int[] {2, 3, 5, 7}, groupHeadsResult.retrieveGroupHeads(maxDoc), maxDoc)); Sort sortWithinGroup3 = new Sort(new SortField("id_2", SortField.Type.STRING, false)); - allGroupHeadsCollector = createRandomCollector(groupField, sortWithinGroup3); - indexSearcher.search(new TermQuery(new Term("content", "random")), allGroupHeadsCollector); + allGroupHeadsCollectorManager = createRandomCollectorManager(groupField, sortWithinGroup3); + groupHeadsResult = + indexSearcher.search( + new TermQuery(new Term("content", "random")), allGroupHeadsCollectorManager); // 7 b/c higher doc id wins, even if order of field is in not in reverse. - assertTrue(arrayContains(new int[] {0, 3, 4, 6}, allGroupHeadsCollector.retrieveGroupHeads())); + assertTrue(arrayContains(new int[] {0, 3, 4, 6}, groupHeadsResult.retrieveGroupHeads())); assertTrue( openBitSetContains( - new int[] {0, 3, 4, 6}, allGroupHeadsCollector.retrieveGroupHeads(maxDoc), maxDoc)); + new int[] {0, 3, 4, 6}, groupHeadsResult.retrieveGroupHeads(maxDoc), maxDoc)); indexSearcher.getIndexReader().close(); dir.close(); @@ -345,13 +354,14 @@ public void testRandom() throws Exception { final String searchTerm = "real" + random().nextInt(3); boolean sortByScoreOnly = random().nextBoolean(); Sort sortWithinGroup = getRandomSort(sortByScoreOnly); - AllGroupHeadsCollector allGroupHeadsCollector = - createRandomCollector("group", sortWithinGroup); - s.search(new TermQuery(new Term("content", searchTerm)), allGroupHeadsCollector); + AllGroupHeadsCollectorManager allGroupHeadsCollectorManager = + createRandomCollectorManager("group", sortWithinGroup); + AllGroupHeadsCollectorManager.GroupHeadsResult groupHeadsResult = + s.search(new TermQuery(new Term("content", searchTerm)), allGroupHeadsCollectorManager); int[] expectedGroupHeads = createExpectedGroupHeads( searchTerm, 
groupDocs, sortWithinGroup, sortByScoreOnly, fieldIdToDocID); - int[] actualGroupHeads = allGroupHeadsCollector.retrieveGroupHeads(); + int[] actualGroupHeads = groupHeadsResult.retrieveGroupHeads(); // The actual group heads contains Lucene ids. Need to change them into our id value. for (int i = 0; i < actualGroupHeads.length; i++) { actualGroupHeads[i] = docIDToFieldId[actualGroupHeads[i]]; @@ -361,7 +371,8 @@ public void testRandom() throws Exception { Arrays.sort(actualGroupHeads); if (VERBOSE) { - System.out.println("Collector: " + allGroupHeadsCollector.getClass().getSimpleName()); + System.out.println( + "CollectorManager: " + allGroupHeadsCollectorManager.getClass().getSimpleName()); System.out.println("Sort within group: " + sortWithinGroup); System.out.println("Num group: " + numGroups); System.out.println("Num doc: " + numDocs); @@ -548,14 +559,15 @@ private Comparator getComparator( }; } - private AllGroupHeadsCollector createRandomCollector(String groupField, Sort sortWithinGroup) { + private AllGroupHeadsCollectorManager createRandomCollectorManager( + String groupField, Sort sortWithinGroup) { if (random().nextBoolean()) { ValueSource vs = new BytesRefFieldSource(groupField); - return AllGroupHeadsCollector.newCollector( - new ValueSourceGroupSelector(vs, new HashMap<>()), sortWithinGroup); + return new AllGroupHeadsCollectorManager<>( + () -> new ValueSourceGroupSelector(vs, new HashMap<>()), sortWithinGroup); } else { - return AllGroupHeadsCollector.newCollector( - new TermGroupSelector(groupField), sortWithinGroup); + return new AllGroupHeadsCollectorManager<>( + () -> new TermGroupSelector(groupField), sortWithinGroup); } }