From 1df552e15161704039ff40119678f401e7469d88 Mon Sep 17 00:00:00 2001 From: Binlong Gao Date: Tue, 13 Jan 2026 15:31:23 +0800 Subject: [PATCH 01/11] Introduce AllGroupHeadsCollectorManager Signed-off-by: Binlong Gao --- lucene/CHANGES.txt | 2 + .../grouping/AllGroupHeadsCollector.java | 94 ++++++++-- .../AllGroupHeadsCollectorManager.java | 171 ++++++++++++++++++ .../grouping/TestAllGroupHeadsCollector.java | 74 ++++---- 4 files changed, 295 insertions(+), 46 deletions(-) create mode 100644 lucene/grouping/src/java/org/apache/lucene/search/grouping/AllGroupHeadsCollectorManager.java diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index fc9c93c614cf..8f02c28181f9 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -91,6 +91,8 @@ Improvements * GITHUB#15453: Avoid unnecessary sorting and instantiations in readMapOfStrings. (Benjamin Lerer) +* GITHUB#15557: Introduce AllGroupHeadsCollectorManager to parallelize search when using AllGroupHeadsCollector. (Binlong Gao) + Optimizations --------------------- * GITHUB#14011: Reduce allocation rate in HNSW concurrent merge. (Viliam Durina) diff --git a/lucene/grouping/src/java/org/apache/lucene/search/grouping/AllGroupHeadsCollector.java b/lucene/grouping/src/java/org/apache/lucene/search/grouping/AllGroupHeadsCollector.java index 4ff5e16c96d4..8f90b7728d01 100644 --- a/lucene/grouping/src/java/org/apache/lucene/search/grouping/AllGroupHeadsCollector.java +++ b/lucene/grouping/src/java/org/apache/lucene/search/grouping/AllGroupHeadsCollector.java @@ -45,6 +45,7 @@ public abstract class AllGroupHeadsCollector extends SimpleCollector { private final GroupSelector groupSelector; protected final Sort sort; + protected final boolean fillSortValues; protected final int[] reversed; protected final int compIDXEnd; @@ -62,15 +63,29 @@ public abstract class AllGroupHeadsCollector extends SimpleCollector { * @param the group value type */ public static AllGroupHeadsCollector newCollector(GroupSelector selector, Sort sort) { + return newCollector(selector, sort, false); + } + + /** + * Create a new AllGroupHeadsCollector based on the type of within-group Sort required + * + * @param selector a GroupSelector to define the groups + * @param sort the within-group sort to use to choose the group head document + * @param fillSortValues whether to store sort values for merging across collectors + * @param the group value type + */ + public static AllGroupHeadsCollector newCollector( + GroupSelector selector, Sort sort, boolean fillSortValues) { if (sort.equals(Sort.RELEVANCE)) { - return new ScoringGroupHeadsCollector<>(selector, sort); + return new ScoringGroupHeadsCollector<>(selector, sort, fillSortValues); } - return new SortingGroupHeadsCollector<>(selector, sort); + return new SortingGroupHeadsCollector<>(selector, sort, fillSortValues); } - private AllGroupHeadsCollector(GroupSelector selector, Sort sort) { + private AllGroupHeadsCollector(GroupSelector selector, Sort sort, boolean fillSortValues) { this.groupSelector = selector; this.sort = sort; + this.fillSortValues = fillSortValues; this.reversed = new int[sort.getSort().length]; final SortField[] sortFields = sort.getSort(); for (int i = 0; i < sortFields.length; i++) { @@ -126,6 +141,17 @@ protected Collection> getCollectedGroupHeads() { return heads.values(); } + /** + * Returns the sort values for a given group. + * + * @param groupValue the group value + * @return the sort values, or null if not available + */ + public Object[] getSortValues(T groupValue) { + GroupHead head = heads.get(groupValue); + return head != null ? head.getSortValues() : null; + } + @Override public void collect(int doc) throws IOException { groupSelector.advanceTo(doc); @@ -232,19 +258,29 @@ protected void setNextReader(LeafReaderContext ctx) throws IOException { * @throws IOException If I/O related errors occur */ protected abstract void updateDocHead(int doc) throws IOException; + + /** + * Returns the sort values for this group head. + * + * @return the sort values, or null if not stored + */ + protected Object[] getSortValues() { + return null; + } } /** General implementation using a {@link FieldComparator} to select the group head */ private static class SortingGroupHeadsCollector extends AllGroupHeadsCollector { - protected SortingGroupHeadsCollector(GroupSelector selector, Sort sort) { - super(selector, sort); + protected SortingGroupHeadsCollector( + GroupSelector selector, Sort sort, boolean fillSortValues) { + super(selector, sort, fillSortValues); } @Override protected GroupHead newGroupHead(int doc, T value, LeafReaderContext ctx, Scorable scorer) throws IOException { - return new SortingGroupHead<>(sort, value, doc, ctx, scorer); + return new SortingGroupHead<>(sort, value, doc, ctx, scorer, fillSortValues); } } @@ -252,20 +288,30 @@ private static class SortingGroupHead extends GroupHead { final FieldComparator[] comparators; final LeafFieldComparator[] leafComparators; + final Object[] sortValues; protected SortingGroupHead( - Sort sort, T groupValue, int doc, LeafReaderContext context, Scorable scorer) + Sort sort, + T groupValue, + int doc, + LeafReaderContext context, + Scorable scorer, + boolean fillSortValues) throws IOException { super(groupValue, doc, context.docBase); final SortField[] sortFields = sort.getSort(); comparators = new FieldComparator[sortFields.length]; leafComparators = new LeafFieldComparator[sortFields.length]; + sortValues = fillSortValues ? new Object[sortFields.length] : null; for (int i = 0; i < sortFields.length; i++) { comparators[i] = sortFields[i].getComparator(1, Pruning.NONE); leafComparators[i] = comparators[i].getLeafComparator(context); leafComparators[i].setScorer(scorer); leafComparators[i].copy(0, doc); leafComparators[i].setBottom(0); + if (fillSortValues) { + sortValues[i] = comparators[i].value(0); + } } } @@ -291,25 +337,34 @@ public int compare(int compIDX, int doc) throws IOException { @Override public void updateDocHead(int doc) throws IOException { - for (LeafFieldComparator comparator : leafComparators) { - comparator.copy(0, doc); - comparator.setBottom(0); + for (int i = 0; i < leafComparators.length; i++) { + leafComparators[i].copy(0, doc); + leafComparators[i].setBottom(0); + if (sortValues != null) { + sortValues[i] = comparators[i].value(0); + } } this.doc = doc + docBase; } + + @Override + protected Object[] getSortValues() { + return sortValues; + } } /** Specialized implementation for sorting by score */ private static class ScoringGroupHeadsCollector extends AllGroupHeadsCollector { - protected ScoringGroupHeadsCollector(GroupSelector selector, Sort sort) { - super(selector, sort); + protected ScoringGroupHeadsCollector( + GroupSelector selector, Sort sort, boolean fillSortValues) { + super(selector, sort, fillSortValues); } @Override protected GroupHead newGroupHead( int doc, T value, LeafReaderContext context, Scorable scorer) throws IOException { - return new ScoringGroupHead<>(scorer, value, doc, context.docBase); + return new ScoringGroupHead<>(scorer, value, doc, context.docBase, fillSortValues); } } @@ -317,12 +372,15 @@ private static class ScoringGroupHead extends GroupHead { private Scorable scorer; private float topScore; + private final Object[] sortValues; - protected ScoringGroupHead(Scorable scorer, T groupValue, int doc, int docBase) + protected ScoringGroupHead( + Scorable scorer, T groupValue, int doc, int docBase, boolean fillSortValues) throws IOException { super(groupValue, doc, docBase); this.scorer = scorer; this.topScore = scorer.score(); + this.sortValues = fillSortValues ? new Object[] {topScore} : null; } @Override @@ -344,6 +402,14 @@ protected int compare(int compIDX, int doc) throws IOException { @Override protected void updateDocHead(int doc) throws IOException { this.doc = doc + docBase; + if (sortValues != null) { + sortValues[0] = topScore; + } + } + + @Override + protected Object[] getSortValues() { + return sortValues; } } } diff --git a/lucene/grouping/src/java/org/apache/lucene/search/grouping/AllGroupHeadsCollectorManager.java b/lucene/grouping/src/java/org/apache/lucene/search/grouping/AllGroupHeadsCollectorManager.java new file mode 100644 index 000000000000..c11e6e900420 --- /dev/null +++ b/lucene/grouping/src/java/org/apache/lucene/search/grouping/AllGroupHeadsCollectorManager.java @@ -0,0 +1,171 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.search.grouping; + +import java.io.IOException; +import java.util.Collection; +import java.util.HashMap; +import java.util.Map; +import org.apache.lucene.queries.function.ValueSource; +import org.apache.lucene.search.CollectorManager; +import org.apache.lucene.search.Sort; +import org.apache.lucene.search.SortField; +import org.apache.lucene.util.Bits; +import org.apache.lucene.util.FixedBitSet; + +/** + * A CollectorManager implementation for AllGroupHeadsCollector. + * + * @lucene.experimental + */ +public class AllGroupHeadsCollectorManager + implements CollectorManager< + AllGroupHeadsCollector, AllGroupHeadsCollectorManager.GroupHeadsResult> { + + /** Result wrapper that allows retrieving group heads as int[] or Bits. */ + public static class GroupHeadsResult { + private final int[] groupHeads; + + GroupHeadsResult(int[] groupHeads) { + this.groupHeads = groupHeads; + } + + public int[] retrieveGroupHeads() { + return groupHeads; + } + + public Bits retrieveGroupHeads(int maxDoc) { + FixedBitSet result = new FixedBitSet(maxDoc); + for (int docId : groupHeads) { + result.set(docId); + } + return result; + } + } + + private static class GroupHeadWithValues { + final T groupValue; + int doc; + final Object[] sortValues; + + GroupHeadWithValues(T groupValue, int doc, Object[] sortValues) { + this.groupValue = groupValue; + this.doc = doc; + this.sortValues = sortValues; + } + } + + private final String groupField; + private final ValueSource valueSource; + private final Map valueSourceContext; + private final Sort sortWithinGroup; + + /** Creates a new AllGroupHeadsCollectorManager for TermGroupSelector. */ + public AllGroupHeadsCollectorManager(String groupField, Sort sortWithinGroup) { + this.groupField = groupField; + this.valueSource = null; + this.valueSourceContext = null; + this.sortWithinGroup = sortWithinGroup; + } + + /** Creates a new AllGroupHeadsCollectorManager for ValueSourceGroupSelector. */ + public AllGroupHeadsCollectorManager( + ValueSource valueSource, Map valueSourceContext, Sort sortWithinGroup) { + this.groupField = null; + this.valueSource = valueSource; + this.valueSourceContext = valueSourceContext; + this.sortWithinGroup = sortWithinGroup; + } + + @Override + public AllGroupHeadsCollector newCollector() throws IOException { + GroupSelector newGroupSelector; + if (groupField != null) { + newGroupSelector = new TermGroupSelector(groupField); + } else { + newGroupSelector = new ValueSourceGroupSelector(valueSource, valueSourceContext); + } + + return AllGroupHeadsCollector.newCollector(newGroupSelector, sortWithinGroup, true); + } + + @Override + public GroupHeadsResult reduce(Collection> collectors) { + if (collectors.isEmpty()) { + return new GroupHeadsResult(new int[0]); + } + + if (collectors.size() == 1) { + return new GroupHeadsResult(collectors.iterator().next().retrieveGroupHeads()); + } + + Map> mergedHeads = new HashMap<>(); + SortField[] sortFields = sortWithinGroup.getSort(); + + for (AllGroupHeadsCollector collector : collectors) { + mergeCollectorHeads(collector, mergedHeads, sortFields); + } + + return new GroupHeadsResult(mergedHeads.values().stream().mapToInt(h -> h.doc).toArray()); + } + + @SuppressWarnings("unchecked") + private void mergeCollectorHeads( + AllGroupHeadsCollector collector, + Map> mergedHeads, + SortField[] sortFields) { + Collection> heads = + (Collection>) collector.getCollectedGroupHeads(); + for (AllGroupHeadsCollector.GroupHead head : heads) { + Object[] sortValues = collector.getSortValues(head.groupValue); + GroupHeadWithValues existing = mergedHeads.get(head.groupValue); + if (existing == null) { + mergedHeads.put( + head.groupValue, new GroupHeadWithValues<>(head.groupValue, head.doc, sortValues)); + } else if (sortValues != null && existing.sortValues != null) { + int cmp = compareValues(sortValues, existing.sortValues, sortFields); + if (cmp > 0 || (cmp == 0 && head.doc < existing.doc)) { + mergedHeads.put( + head.groupValue, new GroupHeadWithValues<>(head.groupValue, head.doc, sortValues)); + } + } + } + } + + @SuppressWarnings({"unchecked", "rawtypes"}) + private int compareValues(Object[] values1, Object[] values2, SortField[] sortFields) { + for (int i = 0; i < sortFields.length; i++) { + int cmp = 0; + if (values1[i] == null) { + cmp = values2[i] == null ? 0 : -1; + } else if (values2[i] == null) { + cmp = 1; + } else if (values1[i] instanceof Comparable) { + cmp = ((Comparable) values1[i]).compareTo(values2[i]); + } + if (cmp != 0) { + // For SCORE type, natural order is descending (higher is better) + // For other types, natural order is ascending (lower is better) + // reverse=true flips the natural order + boolean naturalDescending = sortFields[i].getType() == SortField.Type.SCORE; + boolean wantDescending = naturalDescending != sortFields[i].getReverse(); + return wantDescending ? cmp : -cmp; + } + } + return 0; + } +} diff --git a/lucene/grouping/src/test/org/apache/lucene/search/grouping/TestAllGroupHeadsCollector.java b/lucene/grouping/src/test/org/apache/lucene/search/grouping/TestAllGroupHeadsCollector.java index 2e54b1158f2e..2ba22bbdaf88 100644 --- a/lucene/grouping/src/test/org/apache/lucene/search/grouping/TestAllGroupHeadsCollector.java +++ b/lucene/grouping/src/test/org/apache/lucene/search/grouping/TestAllGroupHeadsCollector.java @@ -136,45 +136,54 @@ public void testBasic() throws Exception { int maxDoc = reader.maxDoc(); Sort sortWithinGroup = new Sort(new SortField("id_1", SortField.Type.INT, true)); - AllGroupHeadsCollector allGroupHeadsCollector = - createRandomCollector(groupField, sortWithinGroup); - indexSearcher.search(new TermQuery(new Term("content", "random")), allGroupHeadsCollector); - assertTrue(arrayContains(new int[] {2, 3, 5, 7}, allGroupHeadsCollector.retrieveGroupHeads())); + AllGroupHeadsCollectorManager allGroupHeadsCollectorManager = + createRandomCollectorManager(groupField, sortWithinGroup); + AllGroupHeadsCollectorManager.GroupHeadsResult groupHeadsResult = + indexSearcher.search( + new TermQuery(new Term("content", "random")), allGroupHeadsCollectorManager); + assertTrue(arrayContains(new int[] {2, 3, 5, 7}, groupHeadsResult.retrieveGroupHeads())); assertTrue( openBitSetContains( - new int[] {2, 3, 5, 7}, allGroupHeadsCollector.retrieveGroupHeads(maxDoc), maxDoc)); + new int[] {2, 3, 5, 7}, groupHeadsResult.retrieveGroupHeads(maxDoc), maxDoc)); - allGroupHeadsCollector = createRandomCollector(groupField, sortWithinGroup); - indexSearcher.search(new TermQuery(new Term("content", "some")), allGroupHeadsCollector); - assertTrue(arrayContains(new int[] {2, 3, 4}, allGroupHeadsCollector.retrieveGroupHeads())); + allGroupHeadsCollectorManager = createRandomCollectorManager(groupField, sortWithinGroup); + groupHeadsResult = + indexSearcher.search( + new TermQuery(new Term("content", "some")), allGroupHeadsCollectorManager); + assertTrue(arrayContains(new int[] {2, 3, 4}, groupHeadsResult.retrieveGroupHeads())); assertTrue( openBitSetContains( - new int[] {2, 3, 4}, allGroupHeadsCollector.retrieveGroupHeads(maxDoc), maxDoc)); + new int[] {2, 3, 4}, groupHeadsResult.retrieveGroupHeads(maxDoc), maxDoc)); - allGroupHeadsCollector = createRandomCollector(groupField, sortWithinGroup); - indexSearcher.search(new TermQuery(new Term("content", "blob")), allGroupHeadsCollector); - assertTrue(arrayContains(new int[] {1, 5}, allGroupHeadsCollector.retrieveGroupHeads())); + allGroupHeadsCollectorManager = createRandomCollectorManager(groupField, sortWithinGroup); + groupHeadsResult = + indexSearcher.search( + new TermQuery(new Term("content", "blob")), allGroupHeadsCollectorManager); + assertTrue(arrayContains(new int[] {1, 5}, groupHeadsResult.retrieveGroupHeads())); assertTrue( - openBitSetContains( - new int[] {1, 5}, allGroupHeadsCollector.retrieveGroupHeads(maxDoc), maxDoc)); + openBitSetContains(new int[] {1, 5}, groupHeadsResult.retrieveGroupHeads(maxDoc), maxDoc)); // STRING sort type triggers different implementation Sort sortWithinGroup2 = new Sort(new SortField("id_2", SortField.Type.STRING, true)); - allGroupHeadsCollector = createRandomCollector(groupField, sortWithinGroup2); - indexSearcher.search(new TermQuery(new Term("content", "random")), allGroupHeadsCollector); - assertTrue(arrayContains(new int[] {2, 3, 5, 7}, allGroupHeadsCollector.retrieveGroupHeads())); + allGroupHeadsCollectorManager = createRandomCollectorManager(groupField, sortWithinGroup2); + groupHeadsResult = + indexSearcher.search( + new TermQuery(new Term("content", "random")), allGroupHeadsCollectorManager); + assertTrue(arrayContains(new int[] {2, 3, 5, 7}, groupHeadsResult.retrieveGroupHeads())); assertTrue( openBitSetContains( - new int[] {2, 3, 5, 7}, allGroupHeadsCollector.retrieveGroupHeads(maxDoc), maxDoc)); + new int[] {2, 3, 5, 7}, groupHeadsResult.retrieveGroupHeads(maxDoc), maxDoc)); Sort sortWithinGroup3 = new Sort(new SortField("id_2", SortField.Type.STRING, false)); - allGroupHeadsCollector = createRandomCollector(groupField, sortWithinGroup3); - indexSearcher.search(new TermQuery(new Term("content", "random")), allGroupHeadsCollector); + allGroupHeadsCollectorManager = createRandomCollectorManager(groupField, sortWithinGroup3); + groupHeadsResult = + indexSearcher.search( + new TermQuery(new Term("content", "random")), allGroupHeadsCollectorManager); // 7 b/c higher doc id wins, even if order of field is in not in reverse. - assertTrue(arrayContains(new int[] {0, 3, 4, 6}, allGroupHeadsCollector.retrieveGroupHeads())); + assertTrue(arrayContains(new int[] {0, 3, 4, 6}, groupHeadsResult.retrieveGroupHeads())); assertTrue( openBitSetContains( - new int[] {0, 3, 4, 6}, allGroupHeadsCollector.retrieveGroupHeads(maxDoc), maxDoc)); + new int[] {0, 3, 4, 6}, groupHeadsResult.retrieveGroupHeads(maxDoc), maxDoc)); indexSearcher.getIndexReader().close(); dir.close(); @@ -345,13 +354,14 @@ public void testRandom() throws Exception { final String searchTerm = "real" + random().nextInt(3); boolean sortByScoreOnly = random().nextBoolean(); Sort sortWithinGroup = getRandomSort(sortByScoreOnly); - AllGroupHeadsCollector allGroupHeadsCollector = - createRandomCollector("group", sortWithinGroup); - s.search(new TermQuery(new Term("content", searchTerm)), allGroupHeadsCollector); + AllGroupHeadsCollectorManager allGroupHeadsCollectorManager = + new AllGroupHeadsCollectorManager("group", sortWithinGroup); + AllGroupHeadsCollectorManager.GroupHeadsResult groupHeadsResult = + s.search(new TermQuery(new Term("content", searchTerm)), allGroupHeadsCollectorManager); int[] expectedGroupHeads = createExpectedGroupHeads( searchTerm, groupDocs, sortWithinGroup, sortByScoreOnly, fieldIdToDocID); - int[] actualGroupHeads = allGroupHeadsCollector.retrieveGroupHeads(); + int[] actualGroupHeads = groupHeadsResult.retrieveGroupHeads(); // The actual group heads contains Lucene ids. Need to change them into our id value. for (int i = 0; i < actualGroupHeads.length; i++) { actualGroupHeads[i] = docIDToFieldId[actualGroupHeads[i]]; @@ -361,7 +371,8 @@ public void testRandom() throws Exception { Arrays.sort(actualGroupHeads); if (VERBOSE) { - System.out.println("Collector: " + allGroupHeadsCollector.getClass().getSimpleName()); + System.out.println( + "CollectorManager: " + allGroupHeadsCollectorManager.getClass().getSimpleName()); System.out.println("Sort within group: " + sortWithinGroup); System.out.println("Num group: " + numGroups); System.out.println("Num doc: " + numDocs); @@ -548,14 +559,13 @@ private Comparator getComparator( }; } - private AllGroupHeadsCollector createRandomCollector(String groupField, Sort sortWithinGroup) { + private AllGroupHeadsCollectorManager createRandomCollectorManager( + String groupField, Sort sortWithinGroup) { if (random().nextBoolean()) { ValueSource vs = new BytesRefFieldSource(groupField); - return AllGroupHeadsCollector.newCollector( - new ValueSourceGroupSelector(vs, new HashMap<>()), sortWithinGroup); + return new AllGroupHeadsCollectorManager(vs, new HashMap<>(), sortWithinGroup); } else { - return AllGroupHeadsCollector.newCollector( - new TermGroupSelector(groupField), sortWithinGroup); + return new AllGroupHeadsCollectorManager(groupField, sortWithinGroup); } } From 75314b437762eb707d174eef55aab9fd7abaaa65 Mon Sep 17 00:00:00 2001 From: Binlong Gao Date: Tue, 13 Jan 2026 15:52:21 +0800 Subject: [PATCH 02/11] Modify change log Signed-off-by: Binlong Gao --- lucene/CHANGES.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index 8f02c28181f9..b29bc8cc1263 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -91,7 +91,7 @@ Improvements * GITHUB#15453: Avoid unnecessary sorting and instantiations in readMapOfStrings. (Benjamin Lerer) -* GITHUB#15557: Introduce AllGroupHeadsCollectorManager to parallelize search when using AllGroupHeadsCollector. (Binlong Gao) +* GITHUB#15565: Introduce AllGroupHeadsCollectorManager to parallelize search when using AllGroupHeadsCollector. (Binlong Gao) Optimizations --------------------- From b60e77cf97832bf04f5dd25fae478514ab39ea18 Mon Sep 17 00:00:00 2001 From: Binlong Gao Date: Tue, 13 Jan 2026 16:26:22 +0800 Subject: [PATCH 03/11] Remove unused field in GroupHeadWithValue Signed-off-by: Binlong Gao --- .../AllGroupHeadsCollectorManager.java | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/lucene/grouping/src/java/org/apache/lucene/search/grouping/AllGroupHeadsCollectorManager.java b/lucene/grouping/src/java/org/apache/lucene/search/grouping/AllGroupHeadsCollectorManager.java index c11e6e900420..86b7cf15d3aa 100644 --- a/lucene/grouping/src/java/org/apache/lucene/search/grouping/AllGroupHeadsCollectorManager.java +++ b/lucene/grouping/src/java/org/apache/lucene/search/grouping/AllGroupHeadsCollectorManager.java @@ -57,13 +57,11 @@ public Bits retrieveGroupHeads(int maxDoc) { } } - private static class GroupHeadWithValues { - final T groupValue; + private static class GroupHeadWithValues { int doc; final Object[] sortValues; - GroupHeadWithValues(T groupValue, int doc, Object[] sortValues) { - this.groupValue = groupValue; + GroupHeadWithValues(int doc, Object[] sortValues) { this.doc = doc; this.sortValues = sortValues; } @@ -113,7 +111,7 @@ public GroupHeadsResult reduce(Collection> collectors) return new GroupHeadsResult(collectors.iterator().next().retrieveGroupHeads()); } - Map> mergedHeads = new HashMap<>(); + Map mergedHeads = new HashMap<>(); SortField[] sortFields = sortWithinGroup.getSort(); for (AllGroupHeadsCollector collector : collectors) { @@ -126,21 +124,19 @@ public GroupHeadsResult reduce(Collection> collectors) @SuppressWarnings("unchecked") private void mergeCollectorHeads( AllGroupHeadsCollector collector, - Map> mergedHeads, + Map mergedHeads, SortField[] sortFields) { Collection> heads = (Collection>) collector.getCollectedGroupHeads(); for (AllGroupHeadsCollector.GroupHead head : heads) { Object[] sortValues = collector.getSortValues(head.groupValue); - GroupHeadWithValues existing = mergedHeads.get(head.groupValue); + GroupHeadWithValues existing = mergedHeads.get(head.groupValue); if (existing == null) { - mergedHeads.put( - head.groupValue, new GroupHeadWithValues<>(head.groupValue, head.doc, sortValues)); + mergedHeads.put(head.groupValue, new GroupHeadWithValues(head.doc, sortValues)); } else if (sortValues != null && existing.sortValues != null) { int cmp = compareValues(sortValues, existing.sortValues, sortFields); if (cmp > 0 || (cmp == 0 && head.doc < existing.doc)) { - mergedHeads.put( - head.groupValue, new GroupHeadWithValues<>(head.groupValue, head.doc, sortValues)); + mergedHeads.put(head.groupValue, new GroupHeadWithValues(head.doc, sortValues)); } } } From 65a0478e97ca983d17fbe68f32d2d4b4d6ceac34 Mon Sep 17 00:00:00 2001 From: Binlong Gao Date: Wed, 15 Apr 2026 18:10:49 +0800 Subject: [PATCH 04/11] x --- .../AllGroupHeadsCollectorManager.java | 52 +++++++------------ .../grouping/TestAllGroupHeadsCollector.java | 24 ++++----- 2 files changed, 30 insertions(+), 46 deletions(-) diff --git a/lucene/grouping/src/java/org/apache/lucene/search/grouping/AllGroupHeadsCollectorManager.java b/lucene/grouping/src/java/org/apache/lucene/search/grouping/AllGroupHeadsCollectorManager.java index 86b7cf15d3aa..ee94d18cb17e 100644 --- a/lucene/grouping/src/java/org/apache/lucene/search/grouping/AllGroupHeadsCollectorManager.java +++ b/lucene/grouping/src/java/org/apache/lucene/search/grouping/AllGroupHeadsCollectorManager.java @@ -20,7 +20,7 @@ import java.util.Collection; import java.util.HashMap; import java.util.Map; -import org.apache.lucene.queries.function.ValueSource; +import java.util.function.Supplier; import org.apache.lucene.search.CollectorManager; import org.apache.lucene.search.Sort; import org.apache.lucene.search.SortField; @@ -32,9 +32,9 @@ * * @lucene.experimental */ -public class AllGroupHeadsCollectorManager +public class AllGroupHeadsCollectorManager implements CollectorManager< - AllGroupHeadsCollector, AllGroupHeadsCollectorManager.GroupHeadsResult> { + AllGroupHeadsCollector, AllGroupHeadsCollectorManager.GroupHeadsResult> { /** Result wrapper that allows retrieving group heads as int[] or Bits. */ public static class GroupHeadsResult { @@ -67,42 +67,28 @@ private static class GroupHeadWithValues { } } - private final String groupField; - private final ValueSource valueSource; - private final Map valueSourceContext; + private final Supplier> groupSelectorFactory; private final Sort sortWithinGroup; - /** Creates a new AllGroupHeadsCollectorManager for TermGroupSelector. */ - public AllGroupHeadsCollectorManager(String groupField, Sort sortWithinGroup) { - this.groupField = groupField; - this.valueSource = null; - this.valueSourceContext = null; - this.sortWithinGroup = sortWithinGroup; - } - - /** Creates a new AllGroupHeadsCollectorManager for ValueSourceGroupSelector. */ + /** + * Creates a new AllGroupHeadsCollectorManager. + * + * @param groupSelectorFactory factory to create group selectors for each collector + * @param sortWithinGroup the sort to use within each group to determine the group head + */ public AllGroupHeadsCollectorManager( - ValueSource valueSource, Map valueSourceContext, Sort sortWithinGroup) { - this.groupField = null; - this.valueSource = valueSource; - this.valueSourceContext = valueSourceContext; + Supplier> groupSelectorFactory, Sort sortWithinGroup) { + this.groupSelectorFactory = groupSelectorFactory; this.sortWithinGroup = sortWithinGroup; } @Override - public AllGroupHeadsCollector newCollector() throws IOException { - GroupSelector newGroupSelector; - if (groupField != null) { - newGroupSelector = new TermGroupSelector(groupField); - } else { - newGroupSelector = new ValueSourceGroupSelector(valueSource, valueSourceContext); - } - - return AllGroupHeadsCollector.newCollector(newGroupSelector, sortWithinGroup, true); + public AllGroupHeadsCollector newCollector() throws IOException { + return AllGroupHeadsCollector.newCollector(groupSelectorFactory.get(), sortWithinGroup, true); } @Override - public GroupHeadsResult reduce(Collection> collectors) { + public GroupHeadsResult reduce(Collection> collectors) { if (collectors.isEmpty()) { return new GroupHeadsResult(new int[0]); } @@ -114,20 +100,18 @@ public GroupHeadsResult reduce(Collection> collectors) Map mergedHeads = new HashMap<>(); SortField[] sortFields = sortWithinGroup.getSort(); - for (AllGroupHeadsCollector collector : collectors) { + for (AllGroupHeadsCollector collector : collectors) { mergeCollectorHeads(collector, mergedHeads, sortFields); } return new GroupHeadsResult(mergedHeads.values().stream().mapToInt(h -> h.doc).toArray()); } - @SuppressWarnings("unchecked") - private void mergeCollectorHeads( + private void mergeCollectorHeads( AllGroupHeadsCollector collector, Map mergedHeads, SortField[] sortFields) { - Collection> heads = - (Collection>) collector.getCollectedGroupHeads(); + Collection> heads = collector.getCollectedGroupHeads(); for (AllGroupHeadsCollector.GroupHead head : heads) { Object[] sortValues = collector.getSortValues(head.groupValue); GroupHeadWithValues existing = mergedHeads.get(head.groupValue); diff --git a/lucene/grouping/src/test/org/apache/lucene/search/grouping/TestAllGroupHeadsCollector.java b/lucene/grouping/src/test/org/apache/lucene/search/grouping/TestAllGroupHeadsCollector.java index 2ba22bbdaf88..a5887c6d855f 100644 --- a/lucene/grouping/src/test/org/apache/lucene/search/grouping/TestAllGroupHeadsCollector.java +++ b/lucene/grouping/src/test/org/apache/lucene/search/grouping/TestAllGroupHeadsCollector.java @@ -136,10 +136,10 @@ public void testBasic() throws Exception { int maxDoc = reader.maxDoc(); Sort sortWithinGroup = new Sort(new SortField("id_1", SortField.Type.INT, true)); - AllGroupHeadsCollectorManager allGroupHeadsCollectorManager = + AllGroupHeadsCollectorManager allGroupHeadsCollectorManager = createRandomCollectorManager(groupField, sortWithinGroup); AllGroupHeadsCollectorManager.GroupHeadsResult groupHeadsResult = - indexSearcher.search( + (AllGroupHeadsCollectorManager.GroupHeadsResult) indexSearcher.search( new TermQuery(new Term("content", "random")), allGroupHeadsCollectorManager); assertTrue(arrayContains(new int[] {2, 3, 5, 7}, groupHeadsResult.retrieveGroupHeads())); assertTrue( @@ -148,7 +148,7 @@ public void testBasic() throws Exception { allGroupHeadsCollectorManager = createRandomCollectorManager(groupField, sortWithinGroup); groupHeadsResult = - indexSearcher.search( + (AllGroupHeadsCollectorManager.GroupHeadsResult) indexSearcher.search( new TermQuery(new Term("content", "some")), allGroupHeadsCollectorManager); assertTrue(arrayContains(new int[] {2, 3, 4}, groupHeadsResult.retrieveGroupHeads())); assertTrue( @@ -157,7 +157,7 @@ public void testBasic() throws Exception { allGroupHeadsCollectorManager = createRandomCollectorManager(groupField, sortWithinGroup); groupHeadsResult = - indexSearcher.search( + (AllGroupHeadsCollectorManager.GroupHeadsResult) indexSearcher.search( new TermQuery(new Term("content", "blob")), allGroupHeadsCollectorManager); assertTrue(arrayContains(new int[] {1, 5}, groupHeadsResult.retrieveGroupHeads())); assertTrue( @@ -167,7 +167,7 @@ public void testBasic() throws Exception { Sort sortWithinGroup2 = new Sort(new SortField("id_2", SortField.Type.STRING, true)); allGroupHeadsCollectorManager = createRandomCollectorManager(groupField, sortWithinGroup2); groupHeadsResult = - indexSearcher.search( + (AllGroupHeadsCollectorManager.GroupHeadsResult) indexSearcher.search( new TermQuery(new Term("content", "random")), allGroupHeadsCollectorManager); assertTrue(arrayContains(new int[] {2, 3, 5, 7}, groupHeadsResult.retrieveGroupHeads())); assertTrue( @@ -177,7 +177,7 @@ public void testBasic() throws Exception { Sort sortWithinGroup3 = new Sort(new SortField("id_2", SortField.Type.STRING, false)); allGroupHeadsCollectorManager = createRandomCollectorManager(groupField, sortWithinGroup3); groupHeadsResult = - indexSearcher.search( + (AllGroupHeadsCollectorManager.GroupHeadsResult) indexSearcher.search( new TermQuery(new Term("content", "random")), allGroupHeadsCollectorManager); // 7 b/c higher doc id wins, even if order of field is in not in reverse. assertTrue(arrayContains(new int[] {0, 3, 4, 6}, groupHeadsResult.retrieveGroupHeads())); @@ -354,10 +354,10 @@ public void testRandom() throws Exception { final String searchTerm = "real" + random().nextInt(3); boolean sortByScoreOnly = random().nextBoolean(); Sort sortWithinGroup = getRandomSort(sortByScoreOnly); - AllGroupHeadsCollectorManager allGroupHeadsCollectorManager = - new AllGroupHeadsCollectorManager("group", sortWithinGroup); + AllGroupHeadsCollectorManager allGroupHeadsCollectorManager = + new AllGroupHeadsCollectorManager<>(() -> new TermGroupSelector("group"), sortWithinGroup); AllGroupHeadsCollectorManager.GroupHeadsResult groupHeadsResult = - s.search(new TermQuery(new Term("content", searchTerm)), allGroupHeadsCollectorManager); + (AllGroupHeadsCollectorManager.GroupHeadsResult) s.search(new TermQuery(new Term("content", searchTerm)), allGroupHeadsCollectorManager); int[] expectedGroupHeads = createExpectedGroupHeads( searchTerm, groupDocs, sortWithinGroup, sortByScoreOnly, fieldIdToDocID); @@ -559,13 +559,13 @@ private Comparator getComparator( }; } - private AllGroupHeadsCollectorManager createRandomCollectorManager( + private AllGroupHeadsCollectorManager createRandomCollectorManager( String groupField, Sort sortWithinGroup) { if (random().nextBoolean()) { ValueSource vs = new BytesRefFieldSource(groupField); - return new AllGroupHeadsCollectorManager(vs, new HashMap<>(), sortWithinGroup); + return new AllGroupHeadsCollectorManager<>(() -> new ValueSourceGroupSelector(vs, new HashMap<>()), sortWithinGroup); } else { - return new AllGroupHeadsCollectorManager(groupField, sortWithinGroup); + return new AllGroupHeadsCollectorManager(() -> new TermGroupSelector(groupField), sortWithinGroup); } } From 17cc63715a564b17195d4c7308a8fa032c95bec2 Mon Sep 17 00:00:00 2001 From: Binlong Gao Date: Tue, 28 Apr 2026 14:52:21 +0800 Subject: [PATCH 05/11] Optimize some code Signed-off-by: Binlong Gao --- .../AllGroupHeadsCollectorManager.java | 11 ++-------- .../grouping/TestAllGroupHeadsCollector.java | 22 +++++++++++-------- 2 files changed, 15 insertions(+), 18 deletions(-) diff --git a/lucene/grouping/src/java/org/apache/lucene/search/grouping/AllGroupHeadsCollectorManager.java b/lucene/grouping/src/java/org/apache/lucene/search/grouping/AllGroupHeadsCollectorManager.java index ee94d18cb17e..5b84b9aabc2f 100644 --- a/lucene/grouping/src/java/org/apache/lucene/search/grouping/AllGroupHeadsCollectorManager.java +++ b/lucene/grouping/src/java/org/apache/lucene/search/grouping/AllGroupHeadsCollectorManager.java @@ -89,14 +89,6 @@ public AllGroupHeadsCollector newCollector() throws IOException { @Override public GroupHeadsResult reduce(Collection> collectors) { - if (collectors.isEmpty()) { - return new GroupHeadsResult(new int[0]); - } - - if (collectors.size() == 1) { - return new GroupHeadsResult(collectors.iterator().next().retrieveGroupHeads()); - } - Map mergedHeads = new HashMap<>(); SortField[] sortFields = sortWithinGroup.getSort(); @@ -111,7 +103,8 @@ private void mergeCollectorHeads( AllGroupHeadsCollector collector, Map mergedHeads, SortField[] sortFields) { - Collection> heads = collector.getCollectedGroupHeads(); + Collection> heads = + collector.getCollectedGroupHeads(); for (AllGroupHeadsCollector.GroupHead head : heads) { Object[] sortValues = collector.getSortValues(head.groupValue); GroupHeadWithValues existing = mergedHeads.get(head.groupValue); diff --git a/lucene/grouping/src/test/org/apache/lucene/search/grouping/TestAllGroupHeadsCollector.java b/lucene/grouping/src/test/org/apache/lucene/search/grouping/TestAllGroupHeadsCollector.java index a5887c6d855f..4bbf46f10a28 100644 --- a/lucene/grouping/src/test/org/apache/lucene/search/grouping/TestAllGroupHeadsCollector.java +++ b/lucene/grouping/src/test/org/apache/lucene/search/grouping/TestAllGroupHeadsCollector.java @@ -139,7 +139,7 @@ public void testBasic() throws Exception { AllGroupHeadsCollectorManager allGroupHeadsCollectorManager = createRandomCollectorManager(groupField, sortWithinGroup); AllGroupHeadsCollectorManager.GroupHeadsResult groupHeadsResult = - (AllGroupHeadsCollectorManager.GroupHeadsResult) indexSearcher.search( + indexSearcher.search( new TermQuery(new Term("content", "random")), allGroupHeadsCollectorManager); assertTrue(arrayContains(new int[] {2, 3, 5, 7}, groupHeadsResult.retrieveGroupHeads())); assertTrue( @@ -148,7 +148,7 @@ public void testBasic() throws Exception { allGroupHeadsCollectorManager = createRandomCollectorManager(groupField, sortWithinGroup); groupHeadsResult = - (AllGroupHeadsCollectorManager.GroupHeadsResult) indexSearcher.search( + indexSearcher.search( new TermQuery(new Term("content", "some")), allGroupHeadsCollectorManager); assertTrue(arrayContains(new int[] {2, 3, 4}, groupHeadsResult.retrieveGroupHeads())); assertTrue( @@ -157,7 +157,7 @@ public void testBasic() throws Exception { allGroupHeadsCollectorManager = createRandomCollectorManager(groupField, sortWithinGroup); groupHeadsResult = - (AllGroupHeadsCollectorManager.GroupHeadsResult) indexSearcher.search( + indexSearcher.search( new TermQuery(new Term("content", "blob")), allGroupHeadsCollectorManager); assertTrue(arrayContains(new int[] {1, 5}, groupHeadsResult.retrieveGroupHeads())); assertTrue( @@ -167,7 +167,7 @@ public void testBasic() throws Exception { Sort sortWithinGroup2 = new Sort(new SortField("id_2", SortField.Type.STRING, true)); allGroupHeadsCollectorManager = createRandomCollectorManager(groupField, sortWithinGroup2); groupHeadsResult = - (AllGroupHeadsCollectorManager.GroupHeadsResult) indexSearcher.search( + indexSearcher.search( new TermQuery(new Term("content", "random")), allGroupHeadsCollectorManager); assertTrue(arrayContains(new int[] {2, 3, 5, 7}, groupHeadsResult.retrieveGroupHeads())); assertTrue( @@ -177,7 +177,7 @@ public void testBasic() throws Exception { Sort sortWithinGroup3 = new Sort(new SortField("id_2", SortField.Type.STRING, false)); allGroupHeadsCollectorManager = createRandomCollectorManager(groupField, sortWithinGroup3); groupHeadsResult = - (AllGroupHeadsCollectorManager.GroupHeadsResult) indexSearcher.search( + indexSearcher.search( new TermQuery(new Term("content", "random")), allGroupHeadsCollectorManager); // 7 b/c higher doc id wins, even if order of field is in not in reverse. assertTrue(arrayContains(new int[] {0, 3, 4, 6}, groupHeadsResult.retrieveGroupHeads())); @@ -355,9 +355,10 @@ public void testRandom() throws Exception { boolean sortByScoreOnly = random().nextBoolean(); Sort sortWithinGroup = getRandomSort(sortByScoreOnly); AllGroupHeadsCollectorManager allGroupHeadsCollectorManager = - new AllGroupHeadsCollectorManager<>(() -> new TermGroupSelector("group"), sortWithinGroup); + new AllGroupHeadsCollectorManager<>( + () -> new TermGroupSelector("group"), sortWithinGroup); AllGroupHeadsCollectorManager.GroupHeadsResult groupHeadsResult = - (AllGroupHeadsCollectorManager.GroupHeadsResult) s.search(new TermQuery(new Term("content", searchTerm)), allGroupHeadsCollectorManager); + s.search(new TermQuery(new Term("content", searchTerm)), allGroupHeadsCollectorManager); int[] expectedGroupHeads = createExpectedGroupHeads( searchTerm, groupDocs, sortWithinGroup, sortByScoreOnly, fieldIdToDocID); @@ -563,9 +564,12 @@ private AllGroupHeadsCollectorManager createRandomCollectorManager( String groupField, Sort sortWithinGroup) { if (random().nextBoolean()) { ValueSource vs = new BytesRefFieldSource(groupField); - return new AllGroupHeadsCollectorManager<>(() -> new ValueSourceGroupSelector(vs, new HashMap<>()), sortWithinGroup); + Map context = new HashMap<>(); + return new AllGroupHeadsCollectorManager<>( + () -> new ValueSourceGroupSelector(vs, context), sortWithinGroup); } else { - return new AllGroupHeadsCollectorManager(() -> new TermGroupSelector(groupField), sortWithinGroup); + return new AllGroupHeadsCollectorManager<>( + () -> new TermGroupSelector(groupField), sortWithinGroup); } } From bb3a689af86d7c7229fa29aaf7c44fb6d50df4db Mon Sep 17 00:00:00 2001 From: Binlong Gao Date: Tue, 28 Apr 2026 15:24:01 +0800 Subject: [PATCH 06/11] Modify change log Signed-off-by: Binlong Gao --- lucene/CHANGES.txt | 6 ++---- .../search/grouping/AllGroupHeadsCollectorManager.java | 4 ++-- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index 6b4eb32000b9..09632ebbbe48 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -123,10 +123,6 @@ Improvements * GITHUB#15453: Avoid unnecessary sorting and instantiations in readMapOfStrings. (Benjamin Lerer) -* GITHUB#15565: Introduce AllGroupHeadsCollectorManager to parallelize search when using AllGroupHeadsCollector. (Binlong Gao) - -* GITHUB#15574: Introduce TopGroupsCollectorManager to parallelize search when using TopGroupsCollector. (Binlong Gao) - * GITHUB#15225: Improve package documentation for org.apache.lucene.util. (Syed Mohammad Saad) * GITHUB#15558: Refactor QueryCache for performance. (Sagar Upadhyaya) @@ -308,6 +304,8 @@ Improvements * GITHUB#15574: Introduce TopGroupsCollectorManager to parallelize search when using TopGroupsCollector. (Binlong Gao) +* GITHUB#15565: Introduce AllGroupHeadsCollectorManager to parallelize search when using AllGroupHeadsCollector. (Binlong Gao) + Optimizations --------------------- * GITHUB#15861: Optimise PhraseScorer by short circuiting non competitive documents in TOP_SCORES mode. (Prithvi S) diff --git a/lucene/grouping/src/java/org/apache/lucene/search/grouping/AllGroupHeadsCollectorManager.java b/lucene/grouping/src/java/org/apache/lucene/search/grouping/AllGroupHeadsCollectorManager.java index 5b84b9aabc2f..be57da439317 100644 --- a/lucene/grouping/src/java/org/apache/lucene/search/grouping/AllGroupHeadsCollectorManager.java +++ b/lucene/grouping/src/java/org/apache/lucene/search/grouping/AllGroupHeadsCollectorManager.java @@ -127,8 +127,8 @@ private int compareValues(Object[] values1, Object[] values2, SortField[] sortFi cmp = values2[i] == null ? 0 : -1; } else if (values2[i] == null) { cmp = 1; - } else if (values1[i] instanceof Comparable) { - cmp = ((Comparable) values1[i]).compareTo(values2[i]); + } else if (values1[i] instanceof Comparable comparable) { + cmp = comparable.compareTo(values2[i]); } if (cmp != 0) { // For SCORE type, natural order is descending (higher is better) From d71d4f3d3954ba676ccfa14b12a81a79458c4f7a Mon Sep 17 00:00:00 2001 From: Binlong Gao Date: Sun, 10 May 2026 21:06:41 +0800 Subject: [PATCH 07/11] Optimize some code Signed-off-by: Binlong Gao --- .../grouping/AllGroupHeadsCollector.java | 15 +----- .../AllGroupHeadsCollectorManager.java | 53 ++++++++++++------- 2 files changed, 34 insertions(+), 34 deletions(-) diff --git a/lucene/grouping/src/java/org/apache/lucene/search/grouping/AllGroupHeadsCollector.java b/lucene/grouping/src/java/org/apache/lucene/search/grouping/AllGroupHeadsCollector.java index 8f90b7728d01..bb7e2385ceaf 100644 --- a/lucene/grouping/src/java/org/apache/lucene/search/grouping/AllGroupHeadsCollector.java +++ b/lucene/grouping/src/java/org/apache/lucene/search/grouping/AllGroupHeadsCollector.java @@ -141,17 +141,6 @@ protected Collection> getCollectedGroupHeads() { return heads.values(); } - /** - * Returns the sort values for a given group. - * - * @param groupValue the group value - * @return the sort values, or null if not available - */ - public Object[] getSortValues(T groupValue) { - GroupHead head = heads.get(groupValue); - return head != null ? head.getSortValues() : null; - } - @Override public void collect(int doc) throws IOException { groupSelector.advanceTo(doc); @@ -264,9 +253,7 @@ protected void setNextReader(LeafReaderContext ctx) throws IOException { * * @return the sort values, or null if not stored */ - protected Object[] getSortValues() { - return null; - } + protected abstract Object[] getSortValues(); } /** General implementation using a {@link FieldComparator} to select the group head */ diff --git a/lucene/grouping/src/java/org/apache/lucene/search/grouping/AllGroupHeadsCollectorManager.java b/lucene/grouping/src/java/org/apache/lucene/search/grouping/AllGroupHeadsCollectorManager.java index be57da439317..311cc71d5338 100644 --- a/lucene/grouping/src/java/org/apache/lucene/search/grouping/AllGroupHeadsCollectorManager.java +++ b/lucene/grouping/src/java/org/apache/lucene/search/grouping/AllGroupHeadsCollectorManager.java @@ -21,33 +21,57 @@ import java.util.HashMap; import java.util.Map; import java.util.function.Supplier; +import org.apache.lucene.index.IndexReader; import org.apache.lucene.search.CollectorManager; +import org.apache.lucene.search.FieldComparator; +import org.apache.lucene.search.Pruning; import org.apache.lucene.search.Sort; import org.apache.lucene.search.SortField; import org.apache.lucene.util.Bits; import org.apache.lucene.util.FixedBitSet; /** - * A CollectorManager implementation for AllGroupHeadsCollector. + * A {@link CollectorManager} implementation for {@link AllGroupHeadsCollector} that collects the + * most relevant document (group head) for each group across multiple segments and merges the + * per-segment results into a single {@link GroupHeadsResult}. * + *

Example usage: + * + *

+ * IndexSearcher searcher = ...; // your IndexSearcher
+ * AllGroupHeadsCollectorManager<BytesRef> manager =
+ *     new AllGroupHeadsCollectorManager<>(
+ *         () -> new TermGroupSelector("category"), Sort.RELEVANCE);
+ * GroupHeadsResult result = searcher.search(new MatchAllDocsQuery(), manager);
+ * Bits groupHeadsBits = result.retrieveGroupHeads(searcher.getIndexReader().maxDoc());
+ * 
+ * + * @param the type of the group value * @lucene.experimental */ public class AllGroupHeadsCollectorManager implements CollectorManager< AllGroupHeadsCollector, AllGroupHeadsCollectorManager.GroupHeadsResult> { - /** Result wrapper that allows retrieving group heads as int[] or Bits. */ - public static class GroupHeadsResult { + /** Holds the merged group heads and provides access as an {@code int[]} or {@link Bits}. */ + public static final class GroupHeadsResult { private final int[] groupHeads; - GroupHeadsResult(int[] groupHeads) { + private GroupHeadsResult(int[] groupHeads) { this.groupHeads = groupHeads; } + /** Returns the group head document IDs as an array. */ public int[] retrieveGroupHeads() { return groupHeads; } + /** + * Returns the group head document IDs as a {@link Bits} set of size {@code maxDoc}, suitable + * for use as a filter. + * + * @param maxDoc The maxDoc of the top level {@link IndexReader}. + */ public Bits retrieveGroupHeads(int maxDoc) { FixedBitSet result = new FixedBitSet(maxDoc); for (int docId : groupHeads) { @@ -106,13 +130,13 @@ private void mergeCollectorHeads( Collection> heads = collector.getCollectedGroupHeads(); for (AllGroupHeadsCollector.GroupHead head : heads) { - Object[] sortValues = collector.getSortValues(head.groupValue); + Object[] sortValues = head.getSortValues(); GroupHeadWithValues existing = mergedHeads.get(head.groupValue); if (existing == null) { mergedHeads.put(head.groupValue, new GroupHeadWithValues(head.doc, sortValues)); } else if (sortValues != null && existing.sortValues != null) { int cmp = compareValues(sortValues, existing.sortValues, sortFields); - if (cmp > 0 || (cmp == 0 && head.doc < existing.doc)) { + if (cmp < 0 || (cmp == 0 && head.doc < existing.doc)) { mergedHeads.put(head.groupValue, new GroupHeadWithValues(head.doc, sortValues)); } } @@ -122,21 +146,10 @@ private void mergeCollectorHeads( @SuppressWarnings({"unchecked", "rawtypes"}) private int compareValues(Object[] values1, Object[] values2, SortField[] sortFields) { for (int i = 0; i < sortFields.length; i++) { - int cmp = 0; - if (values1[i] == null) { - cmp = values2[i] == null ? 0 : -1; - } else if (values2[i] == null) { - cmp = 1; - } else if (values1[i] instanceof Comparable comparable) { - cmp = comparable.compareTo(values2[i]); - } + FieldComparator comparator = sortFields[i].getComparator(1, Pruning.NONE); + int cmp = comparator.compareValues(values1[i], values2[i]); if (cmp != 0) { - // For SCORE type, natural order is descending (higher is better) - // For other types, natural order is ascending (lower is better) - // reverse=true flips the natural order - boolean naturalDescending = sortFields[i].getType() == SortField.Type.SCORE; - boolean wantDescending = naturalDescending != sortFields[i].getReverse(); - return wantDescending ? cmp : -cmp; + return sortFields[i].getReverse() ? -cmp : cmp; } } return 0; From 04ef4cf5abd6e5e4059ea7c3cf98d5d38b586ecd Mon Sep 17 00:00:00 2001 From: Binlong Gao Date: Mon, 11 May 2026 11:02:20 +0800 Subject: [PATCH 08/11] Fix class modifier Signed-off-by: Binlong Gao --- .../lucene/search/grouping/AllGroupHeadsCollectorManager.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lucene/grouping/src/java/org/apache/lucene/search/grouping/AllGroupHeadsCollectorManager.java b/lucene/grouping/src/java/org/apache/lucene/search/grouping/AllGroupHeadsCollectorManager.java index 311cc71d5338..07811c82b154 100644 --- a/lucene/grouping/src/java/org/apache/lucene/search/grouping/AllGroupHeadsCollectorManager.java +++ b/lucene/grouping/src/java/org/apache/lucene/search/grouping/AllGroupHeadsCollectorManager.java @@ -54,7 +54,7 @@ public class AllGroupHeadsCollectorManager AllGroupHeadsCollector, AllGroupHeadsCollectorManager.GroupHeadsResult> { /** Holds the merged group heads and provides access as an {@code int[]} or {@link Bits}. */ - public static final class GroupHeadsResult { + public static class GroupHeadsResult { private final int[] groupHeads; private GroupHeadsResult(int[] groupHeads) { @@ -81,7 +81,7 @@ public Bits retrieveGroupHeads(int maxDoc) { } } - private static class GroupHeadWithValues { + private static final class GroupHeadWithValues { int doc; final Object[] sortValues; From c319a8d4bb21b8f92960272fbb827b9154b96348 Mon Sep 17 00:00:00 2001 From: Binlong Gao Date: Wed, 13 May 2026 17:19:51 +0800 Subject: [PATCH 09/11] Optimize some code Signed-off-by: Binlong Gao --- .../grouping/AllGroupHeadsCollector.java | 61 +++++-------------- .../AllGroupHeadsCollectorManager.java | 25 +++++--- .../grouping/TestAllGroupHeadsCollector.java | 5 +- 3 files changed, 35 insertions(+), 56 deletions(-) diff --git a/lucene/grouping/src/java/org/apache/lucene/search/grouping/AllGroupHeadsCollector.java b/lucene/grouping/src/java/org/apache/lucene/search/grouping/AllGroupHeadsCollector.java index bb7e2385ceaf..8d713d558fd3 100644 --- a/lucene/grouping/src/java/org/apache/lucene/search/grouping/AllGroupHeadsCollector.java +++ b/lucene/grouping/src/java/org/apache/lucene/search/grouping/AllGroupHeadsCollector.java @@ -45,7 +45,6 @@ public abstract class AllGroupHeadsCollector extends SimpleCollector { private final GroupSelector groupSelector; protected final Sort sort; - protected final boolean fillSortValues; protected final int[] reversed; protected final int compIDXEnd; @@ -63,29 +62,15 @@ public abstract class AllGroupHeadsCollector extends SimpleCollector { * @param the group value type */ public static AllGroupHeadsCollector newCollector(GroupSelector selector, Sort sort) { - return newCollector(selector, sort, false); - } - - /** - * Create a new AllGroupHeadsCollector based on the type of within-group Sort required - * - * @param selector a GroupSelector to define the groups - * @param sort the within-group sort to use to choose the group head document - * @param fillSortValues whether to store sort values for merging across collectors - * @param the group value type - */ - public static AllGroupHeadsCollector newCollector( - GroupSelector selector, Sort sort, boolean fillSortValues) { if (sort.equals(Sort.RELEVANCE)) { - return new ScoringGroupHeadsCollector<>(selector, sort, fillSortValues); + return new ScoringGroupHeadsCollector<>(selector, sort); } - return new SortingGroupHeadsCollector<>(selector, sort, fillSortValues); + return new SortingGroupHeadsCollector<>(selector, sort); } - private AllGroupHeadsCollector(GroupSelector selector, Sort sort, boolean fillSortValues) { + private AllGroupHeadsCollector(GroupSelector selector, Sort sort) { this.groupSelector = selector; this.sort = sort; - this.fillSortValues = fillSortValues; this.reversed = new int[sort.getSort().length]; final SortField[] sortFields = sort.getSort(); for (int i = 0; i < sortFields.length; i++) { @@ -259,15 +244,14 @@ protected void setNextReader(LeafReaderContext ctx) throws IOException { /** General implementation using a {@link FieldComparator} to select the group head */ private static class SortingGroupHeadsCollector extends AllGroupHeadsCollector { - protected SortingGroupHeadsCollector( - GroupSelector selector, Sort sort, boolean fillSortValues) { - super(selector, sort, fillSortValues); + protected SortingGroupHeadsCollector(GroupSelector selector, Sort sort) { + super(selector, sort); } @Override protected GroupHead newGroupHead(int doc, T value, LeafReaderContext ctx, Scorable scorer) throws IOException { - return new SortingGroupHead<>(sort, value, doc, ctx, scorer, fillSortValues); + return new SortingGroupHead<>(sort, value, doc, ctx, scorer); } } @@ -278,27 +262,20 @@ private static class SortingGroupHead extends GroupHead { final Object[] sortValues; protected SortingGroupHead( - Sort sort, - T groupValue, - int doc, - LeafReaderContext context, - Scorable scorer, - boolean fillSortValues) + Sort sort, T groupValue, int doc, LeafReaderContext context, Scorable scorer) throws IOException { super(groupValue, doc, context.docBase); final SortField[] sortFields = sort.getSort(); comparators = new FieldComparator[sortFields.length]; leafComparators = new LeafFieldComparator[sortFields.length]; - sortValues = fillSortValues ? new Object[sortFields.length] : null; + sortValues = new Object[sortFields.length]; for (int i = 0; i < sortFields.length; i++) { comparators[i] = sortFields[i].getComparator(1, Pruning.NONE); leafComparators[i] = comparators[i].getLeafComparator(context); leafComparators[i].setScorer(scorer); leafComparators[i].copy(0, doc); leafComparators[i].setBottom(0); - if (fillSortValues) { - sortValues[i] = comparators[i].value(0); - } + sortValues[i] = comparators[i].value(0); } } @@ -327,9 +304,7 @@ public void updateDocHead(int doc) throws IOException { for (int i = 0; i < leafComparators.length; i++) { leafComparators[i].copy(0, doc); leafComparators[i].setBottom(0); - if (sortValues != null) { - sortValues[i] = comparators[i].value(0); - } + sortValues[i] = comparators[i].value(0); } this.doc = doc + docBase; } @@ -343,15 +318,14 @@ protected Object[] getSortValues() { /** Specialized implementation for sorting by score */ private static class ScoringGroupHeadsCollector extends AllGroupHeadsCollector { - protected ScoringGroupHeadsCollector( - GroupSelector selector, Sort sort, boolean fillSortValues) { - super(selector, sort, fillSortValues); + protected ScoringGroupHeadsCollector(GroupSelector selector, Sort sort) { + super(selector, sort); } @Override protected GroupHead newGroupHead( int doc, T value, LeafReaderContext context, Scorable scorer) throws IOException { - return new ScoringGroupHead<>(scorer, value, doc, context.docBase, fillSortValues); + return new ScoringGroupHead<>(scorer, value, doc, context.docBase); } } @@ -361,13 +335,12 @@ private static class ScoringGroupHead extends GroupHead { private float topScore; private final Object[] sortValues; - protected ScoringGroupHead( - Scorable scorer, T groupValue, int doc, int docBase, boolean fillSortValues) + protected ScoringGroupHead(Scorable scorer, T groupValue, int doc, int docBase) throws IOException { super(groupValue, doc, docBase); this.scorer = scorer; this.topScore = scorer.score(); - this.sortValues = fillSortValues ? new Object[] {topScore} : null; + this.sortValues = new Object[] {topScore}; } @Override @@ -389,9 +362,7 @@ protected int compare(int compIDX, int doc) throws IOException { @Override protected void updateDocHead(int doc) throws IOException { this.doc = doc + docBase; - if (sortValues != null) { - sortValues[0] = topScore; - } + sortValues[0] = topScore; } @Override diff --git a/lucene/grouping/src/java/org/apache/lucene/search/grouping/AllGroupHeadsCollectorManager.java b/lucene/grouping/src/java/org/apache/lucene/search/grouping/AllGroupHeadsCollectorManager.java index 07811c82b154..e6203797a34a 100644 --- a/lucene/grouping/src/java/org/apache/lucene/search/grouping/AllGroupHeadsCollectorManager.java +++ b/lucene/grouping/src/java/org/apache/lucene/search/grouping/AllGroupHeadsCollectorManager.java @@ -23,8 +23,6 @@ import java.util.function.Supplier; import org.apache.lucene.index.IndexReader; import org.apache.lucene.search.CollectorManager; -import org.apache.lucene.search.FieldComparator; -import org.apache.lucene.search.Pruning; import org.apache.lucene.search.Sort; import org.apache.lucene.search.SortField; import org.apache.lucene.util.Bits; @@ -108,12 +106,12 @@ public AllGroupHeadsCollectorManager( @Override public AllGroupHeadsCollector newCollector() throws IOException { - return AllGroupHeadsCollector.newCollector(groupSelectorFactory.get(), sortWithinGroup, true); + return AllGroupHeadsCollector.newCollector(groupSelectorFactory.get(), sortWithinGroup); } @Override public GroupHeadsResult reduce(Collection> collectors) { - Map mergedHeads = new HashMap<>(); + Map mergedHeads = new HashMap<>(); SortField[] sortFields = sortWithinGroup.getSort(); for (AllGroupHeadsCollector collector : collectors) { @@ -125,7 +123,7 @@ public GroupHeadsResult reduce(Collection> collectors) private void mergeCollectorHeads( AllGroupHeadsCollector collector, - Map mergedHeads, + Map mergedHeads, SortField[] sortFields) { Collection> heads = collector.getCollectedGroupHeads(); @@ -146,10 +144,21 @@ private void mergeCollectorHeads( @SuppressWarnings({"unchecked", "rawtypes"}) private int compareValues(Object[] values1, Object[] values2, SortField[] sortFields) { for (int i = 0; i < sortFields.length; i++) { - FieldComparator comparator = sortFields[i].getComparator(1, Pruning.NONE); - int cmp = comparator.compareValues(values1[i], values2[i]); + int cmp = 0; + if (values1[i] == null) { + cmp = values2[i] == null ? 0 : -1; + } else if (values2[i] == null) { + cmp = 1; + } else if (values1[i] instanceof Comparable) { + cmp = ((Comparable) values1[i]).compareTo(values2[i]); + } if (cmp != 0) { - return sortFields[i].getReverse() ? -cmp : cmp; + // For SCORE type, natural order is descending (higher is better) + // For other types, natural order is ascending (lower is better) + // reverse=true flips the natural order + boolean naturalDescending = sortFields[i].getType() == SortField.Type.SCORE; + boolean wantDescending = naturalDescending != sortFields[i].getReverse(); + return wantDescending ? -cmp : cmp; } } return 0; diff --git a/lucene/grouping/src/test/org/apache/lucene/search/grouping/TestAllGroupHeadsCollector.java b/lucene/grouping/src/test/org/apache/lucene/search/grouping/TestAllGroupHeadsCollector.java index 4bbf46f10a28..3a71b9bb2487 100644 --- a/lucene/grouping/src/test/org/apache/lucene/search/grouping/TestAllGroupHeadsCollector.java +++ b/lucene/grouping/src/test/org/apache/lucene/search/grouping/TestAllGroupHeadsCollector.java @@ -354,9 +354,8 @@ public void testRandom() throws Exception { final String searchTerm = "real" + random().nextInt(3); boolean sortByScoreOnly = random().nextBoolean(); Sort sortWithinGroup = getRandomSort(sortByScoreOnly); - AllGroupHeadsCollectorManager allGroupHeadsCollectorManager = - new AllGroupHeadsCollectorManager<>( - () -> new TermGroupSelector("group"), sortWithinGroup); + AllGroupHeadsCollectorManager allGroupHeadsCollectorManager = + createRandomCollectorManager("group", sortWithinGroup); AllGroupHeadsCollectorManager.GroupHeadsResult groupHeadsResult = s.search(new TermQuery(new Term("content", searchTerm)), allGroupHeadsCollectorManager); int[] expectedGroupHeads = From 26225847f34fb4fad4f711053c73aaac6dc91168 Mon Sep 17 00:00:00 2001 From: Binlong Gao Date: Wed, 20 May 2026 11:22:24 +0800 Subject: [PATCH 10/11] Cover custom FieldComparator case when merging group heads Signed-off-by: Binlong Gao --- .../grouping/AllGroupHeadsCollector.java | 17 ++++++ .../AllGroupHeadsCollectorManager.java | 54 +++++++++---------- 2 files changed, 43 insertions(+), 28 deletions(-) diff --git a/lucene/grouping/src/java/org/apache/lucene/search/grouping/AllGroupHeadsCollector.java b/lucene/grouping/src/java/org/apache/lucene/search/grouping/AllGroupHeadsCollector.java index 8d713d558fd3..013f83b3f3c6 100644 --- a/lucene/grouping/src/java/org/apache/lucene/search/grouping/AllGroupHeadsCollector.java +++ b/lucene/grouping/src/java/org/apache/lucene/search/grouping/AllGroupHeadsCollector.java @@ -239,6 +239,13 @@ protected void setNextReader(LeafReaderContext ctx) throws IOException { * @return the sort values, or null if not stored */ protected abstract Object[] getSortValues(); + + /** + * Returns the field comparators used to determine the group head ordering. + * + * @return the comparators, one per sort field + */ + protected abstract FieldComparator[] getComparators(); } /** General implementation using a {@link FieldComparator} to select the group head */ @@ -313,6 +320,11 @@ public void updateDocHead(int doc) throws IOException { protected Object[] getSortValues() { return sortValues; } + + @Override + protected FieldComparator[] getComparators() { + return comparators; + } } /** Specialized implementation for sorting by score */ @@ -369,5 +381,10 @@ protected void updateDocHead(int doc) throws IOException { protected Object[] getSortValues() { return sortValues; } + + @Override + protected FieldComparator[] getComparators() { + return null; + } } } diff --git a/lucene/grouping/src/java/org/apache/lucene/search/grouping/AllGroupHeadsCollectorManager.java b/lucene/grouping/src/java/org/apache/lucene/search/grouping/AllGroupHeadsCollectorManager.java index e6203797a34a..ee3cfd8c5a99 100644 --- a/lucene/grouping/src/java/org/apache/lucene/search/grouping/AllGroupHeadsCollectorManager.java +++ b/lucene/grouping/src/java/org/apache/lucene/search/grouping/AllGroupHeadsCollectorManager.java @@ -23,6 +23,7 @@ import java.util.function.Supplier; import org.apache.lucene.index.IndexReader; import org.apache.lucene.search.CollectorManager; +import org.apache.lucene.search.FieldComparator; import org.apache.lucene.search.Sort; import org.apache.lucene.search.SortField; import org.apache.lucene.util.Bits; @@ -125,42 +126,39 @@ private void mergeCollectorHeads( AllGroupHeadsCollector collector, Map mergedHeads, SortField[] sortFields) { - Collection> heads = - collector.getCollectedGroupHeads(); - for (AllGroupHeadsCollector.GroupHead head : heads) { + for (AllGroupHeadsCollector.GroupHead head : collector.getCollectedGroupHeads()) { Object[] sortValues = head.getSortValues(); GroupHeadWithValues existing = mergedHeads.get(head.groupValue); - if (existing == null) { + if (existing == null || isCompetitive(head, sortValues, existing, sortFields)) { mergedHeads.put(head.groupValue, new GroupHeadWithValues(head.doc, sortValues)); - } else if (sortValues != null && existing.sortValues != null) { - int cmp = compareValues(sortValues, existing.sortValues, sortFields); - if (cmp < 0 || (cmp == 0 && head.doc < existing.doc)) { - mergedHeads.put(head.groupValue, new GroupHeadWithValues(head.doc, sortValues)); - } } } } - @SuppressWarnings({"unchecked", "rawtypes"}) - private int compareValues(Object[] values1, Object[] values2, SortField[] sortFields) { - for (int i = 0; i < sortFields.length; i++) { - int cmp = 0; - if (values1[i] == null) { - cmp = values2[i] == null ? 0 : -1; - } else if (values2[i] == null) { - cmp = 1; - } else if (values1[i] instanceof Comparable) { - cmp = ((Comparable) values1[i]).compareTo(values2[i]); - } - if (cmp != 0) { - // For SCORE type, natural order is descending (higher is better) - // For other types, natural order is ascending (lower is better) - // reverse=true flips the natural order - boolean naturalDescending = sortFields[i].getType() == SortField.Type.SCORE; - boolean wantDescending = naturalDescending != sortFields[i].getReverse(); - return wantDescending ? -cmp : cmp; + @SuppressWarnings({"rawtypes"}) + private boolean isCompetitive( + AllGroupHeadsCollector.GroupHead head, + Object[] sortValues, + GroupHeadWithValues existing, + SortField[] sortFields) { + FieldComparator[] comparators = head.getComparators(); + int cmp; + // null comparators means sorting by relevance + if (comparators == null) { + cmp = Float.compare((float) sortValues[0], (float) existing.sortValues[0]); + return cmp > 0 || (cmp == 0 && head.doc < existing.doc); + } else { + cmp = 0; + for (int i = 0; i < sortFields.length; i++) { + @SuppressWarnings({"unchecked"}) + int c = comparators[i].compareValues(sortValues[i], existing.sortValues[i]); + c = sortFields[i].getReverse() ? -c : c; + if (c != 0) { + cmp = c; + break; + } } + return cmp < 0 || (cmp == 0 && head.doc < existing.doc); } - return 0; } } From 5c221245e491464dd94bdd1ab6f49c57bd20b7d2 Mon Sep 17 00:00:00 2001 From: Binlong Gao Date: Wed, 20 May 2026 14:24:29 +0800 Subject: [PATCH 11/11] Do not share context across collectors for ValueSourceGroupSelector Signed-off-by: Binlong Gao --- .../lucene/search/grouping/TestAllGroupHeadsCollector.java | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/lucene/grouping/src/test/org/apache/lucene/search/grouping/TestAllGroupHeadsCollector.java b/lucene/grouping/src/test/org/apache/lucene/search/grouping/TestAllGroupHeadsCollector.java index 7f6f4524b25e..10925d378530 100644 --- a/lucene/grouping/src/test/org/apache/lucene/search/grouping/TestAllGroupHeadsCollector.java +++ b/lucene/grouping/src/test/org/apache/lucene/search/grouping/TestAllGroupHeadsCollector.java @@ -563,9 +563,8 @@ private AllGroupHeadsCollectorManager createRandomCollectorManager( String groupField, Sort sortWithinGroup) { if (random().nextBoolean()) { ValueSource vs = new BytesRefFieldSource(groupField); - Map context = new HashMap<>(); return new AllGroupHeadsCollectorManager<>( - () -> new ValueSourceGroupSelector(vs, context), sortWithinGroup); + () -> new ValueSourceGroupSelector(vs, new HashMap<>()), sortWithinGroup); } else { return new AllGroupHeadsCollectorManager<>( () -> new TermGroupSelector(groupField), sortWithinGroup);