Skip to content

Commit a682654

Browse files
authored
perf(constraint-streams): fused equal indexer (#2346)
1 parent 517b90b commit a682654

33 files changed

Lines changed: 1579 additions & 433 deletions

core/src/main/java/ai/timefold/solver/core/impl/bavet/bi/joiner/BiJoinerComber.java

Lines changed: 37 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -6,83 +6,87 @@
66

77
import ai.timefold.solver.core.api.score.stream.bi.BiJoiner;
88

9+
import org.jspecify.annotations.NullMarked;
10+
import org.jspecify.annotations.Nullable;
11+
912
/**
1013
* Combs an array of {@link BiJoiner} instances into a mergedJoiner and a mergedFiltering.
1114
*
1215
* @param <A>
1316
* @param <B>
1417
*/
18+
@NullMarked
1519
public final class BiJoinerComber<A, B> {
1620

21+
@SafeVarargs
1722
public static <A, B> BiJoinerComber<A, B> comb(BiJoiner<A, B>... joiners) {
18-
List<DefaultBiJoiner<A, B>> defaultJoinerList = new ArrayList<>(joiners.length);
19-
List<BiPredicate<A, B>> filteringList = new ArrayList<>(joiners.length);
23+
var defaultJoinerList = new ArrayList<DefaultBiJoiner<A, B>>(joiners.length);
24+
var filteringList = new ArrayList<BiPredicate<A, B>>(joiners.length);
2025

21-
int indexOfFirstFilter = -1;
26+
var indexOfFirstFilter = -1;
2227
// Make sure all indexing joiners, if any, come before filtering joiners. This is necessary for performance.
23-
for (int i = 0; i < joiners.length; i++) {
24-
BiJoiner<A, B> joiner = joiners[i];
28+
for (var i = 0; i < joiners.length; i++) {
29+
var joiner = joiners[i];
2530
if (joiner instanceof FilteringBiJoiner) {
2631
// From now on, only allow filtering joiners.
2732
indexOfFirstFilter = i;
2833
filteringList.add(((FilteringBiJoiner<A, B>) joiner).getFilter());
2934
} else if (joiner instanceof DefaultBiJoiner) {
3035
if (indexOfFirstFilter >= 0) {
31-
throw new IllegalStateException("Indexing joiner (" + joiner + ") must not follow " +
32-
"a filtering joiner (" + joiners[indexOfFirstFilter] + ").\n" +
33-
"Maybe reorder the joiners such that filtering() joiners are later in the parameter list.");
36+
throw new IllegalStateException("""
37+
Indexing joiner (%s) must not follow a filtering joiner (%s).
38+
Maybe reorder the joiners such that filtering() joiners are later in the parameter list."""
39+
.formatted(joiner, joiners[indexOfFirstFilter]));
3440
}
3541
defaultJoinerList.add((DefaultBiJoiner<A, B>) joiner);
3642
} else {
37-
throw new IllegalArgumentException("The joiner class (" + joiner.getClass() + ") is not supported.");
43+
throw new IllegalArgumentException("The joiner class (%s) is not supported."
44+
.formatted(joiner.getClass()));
3845
}
3946
}
40-
DefaultBiJoiner<A, B> mergedJoiner = DefaultBiJoiner.merge(defaultJoinerList);
41-
BiPredicate<A, B> mergedFiltering = mergeFiltering(filteringList);
47+
var mergedJoiner = DefaultBiJoiner.merge(defaultJoinerList);
48+
var mergedFiltering = mergeFiltering(filteringList);
4249
return new BiJoinerComber<>(mergedJoiner, mergedFiltering);
4350
}
4451

45-
private static <A, B> BiPredicate<A, B> mergeFiltering(List<BiPredicate<A, B>> filteringList) {
46-
if (filteringList.isEmpty()) {
47-
return null;
48-
}
49-
switch (filteringList.size()) {
50-
case 1:
51-
return filteringList.get(0);
52-
case 2:
53-
return filteringList.get(0).and(filteringList.get(1));
54-
default:
55-
// Avoid predicate.and() when more than 2 predicates for debugging and potentially performance
56-
return (A a, B b) -> {
57-
for (BiPredicate<A, B> predicate : filteringList) {
52+
@SuppressWarnings("unchecked")
53+
private static <A, B> @Nullable BiPredicate<A, B> mergeFiltering(List<BiPredicate<A, B>> filteringList) {
54+
return switch (filteringList.size()) {
55+
case 0 -> null;
56+
case 1 -> filteringList.getFirst();
57+
default -> {
58+
// Avoid predicate.and() for debugging and potential performance
59+
var filteringArray = filteringList.toArray(new BiPredicate[0]);
60+
yield (A a, B b) -> {
61+
for (var predicate : filteringArray) {
5862
if (!predicate.test(a, b)) {
5963
return false;
6064
}
6165
}
6266
return true;
6367
};
64-
}
68+
}
69+
};
6570
}
6671

6772
private DefaultBiJoiner<A, B> mergedJoiner;
68-
private final BiPredicate<A, B> mergedFiltering;
73+
private final @Nullable BiPredicate<A, B> mergedFiltering;
6974

70-
public BiJoinerComber(DefaultBiJoiner<A, B> mergedJoiner, BiPredicate<A, B> mergedFiltering) {
75+
public BiJoinerComber(DefaultBiJoiner<A, B> mergedJoiner, @Nullable BiPredicate<A, B> mergedFiltering) {
7176
this.mergedJoiner = mergedJoiner;
7277
this.mergedFiltering = mergedFiltering;
7378
}
7479

7580
/**
76-
* @return never null
81+
* Returns the merged indexing joiner,
82+
* reordered equal-first so the indexer chain always has its (merged) equal level at the top.
83+
* Computed on read to also cover {@link #addJoiner} appends.
7784
*/
7885
public DefaultBiJoiner<A, B> getMergedJoiner() {
79-
return mergedJoiner;
86+
return mergedJoiner.reorderedEqualsFirst();
8087
}
8188

82-
/**
83-
* @return null if not applicable
84-
*/
85-
public BiPredicate<A, B> getMergedFiltering() {
89+
public @Nullable BiPredicate<A, B> getMergedFiltering() {
8690
return mergedFiltering;
8791
}
8892

core/src/main/java/ai/timefold/solver/core/impl/bavet/bi/joiner/DefaultBiJoiner.java

Lines changed: 48 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,12 @@
99
import ai.timefold.solver.core.impl.bavet.common.joiner.AbstractJoiner;
1010
import ai.timefold.solver.core.impl.bavet.common.joiner.JoinerType;
1111

12-
import org.jspecify.annotations.NonNull;
12+
import org.jspecify.annotations.NullMarked;
1313

14-
public final class DefaultBiJoiner<A, B> extends AbstractJoiner<B> implements BiJoiner<A, B> {
14+
@NullMarked
15+
public final class DefaultBiJoiner<A, B>
16+
extends AbstractJoiner<B>
17+
implements BiJoiner<A, B> {
1518

1619
private static final DefaultBiJoiner NONE = new DefaultBiJoiner(new Function[0], new JoinerType[0], new Function[0]);
1720

@@ -28,40 +31,64 @@ public DefaultBiJoiner(Function<A, ?>[] leftMappings, JoinerType[] joinerTypes,
2831
}
2932

3033
public static <A, B> DefaultBiJoiner<A, B> merge(List<DefaultBiJoiner<A, B>> joinerList) {
31-
if (joinerList.size() == 1) {
32-
return joinerList.get(0);
33-
}
34-
return joinerList.stream().reduce(NONE, DefaultBiJoiner::and);
34+
return switch (joinerList.size()) {
35+
case 0 -> NONE;
36+
case 1 -> joinerList.getFirst();
37+
default -> joinerList.stream().reduce(NONE, DefaultBiJoiner::and);
38+
};
3539
}
3640

3741
@Override
38-
public @NonNull DefaultBiJoiner<A, B> and(@NonNull BiJoiner<A, B> otherJoiner) {
39-
DefaultBiJoiner<A, B> castJoiner = (DefaultBiJoiner<A, B>) otherJoiner;
40-
int joinerCount = getJoinerCount();
41-
int castJoinerCount = castJoiner.getJoinerCount();
42-
int newJoinerCount = joinerCount + castJoinerCount;
43-
JoinerType[] newJoinerTypes = Arrays.copyOf(this.joinerTypes, newJoinerCount);
44-
Function[] newLeftMappings = Arrays.copyOf(this.leftMappings, newJoinerCount);
45-
Function[] newRightMappings = Arrays.copyOf(this.rightMappings, newJoinerCount);
46-
for (int i = 0; i < castJoinerCount; i++) {
47-
int newJoinerIndex = i + joinerCount;
42+
public DefaultBiJoiner<A, B> and(BiJoiner<A, B> otherJoiner) {
43+
var castJoiner = (DefaultBiJoiner<A, B>) otherJoiner;
44+
var joinerCount = getJoinerCount();
45+
var castJoinerCount = castJoiner.getJoinerCount();
46+
var newJoinerCount = joinerCount + castJoinerCount;
47+
var newJoinerTypes = Arrays.copyOf(this.joinerTypes, newJoinerCount);
48+
var newLeftMappings = Arrays.copyOf(this.leftMappings, newJoinerCount);
49+
var newRightMappings = Arrays.copyOf(this.rightMappings, newJoinerCount);
50+
for (var i = 0; i < castJoinerCount; i++) {
51+
var newJoinerIndex = i + joinerCount;
4852
newJoinerTypes[newJoinerIndex] = castJoiner.getJoinerType(i);
4953
newLeftMappings[newJoinerIndex] = castJoiner.getLeftMapping(i);
5054
newRightMappings[newJoinerIndex] = castJoiner.getRightMapping(i);
5155
}
5256
return new DefaultBiJoiner<>(newLeftMappings, newJoinerTypes, newRightMappings);
5357
}
5458

59+
/**
60+
* @return this if already equal-first (or single joiner); otherwise a copy with all
61+
* {@link JoinerType#EQUAL} joiners moved to the front (stable, see
62+
* {@link AbstractJoiner#equalsFirstSortedPositions}).
63+
*/
64+
public DefaultBiJoiner<A, B> reorderedEqualsFirst() {
65+
var order = equalsFirstSortedPositions(joinerTypes);
66+
if (order == null) {
67+
return this;
68+
}
69+
var count = order.length;
70+
var newLeftMappings = new Function[count];
71+
var newJoinerTypes = new JoinerType[count];
72+
var newRightMappings = new Function[count];
73+
for (var i = 0; i < count; i++) {
74+
var from = order[i];
75+
newLeftMappings[i] = leftMappings[from];
76+
newJoinerTypes[i] = joinerTypes[from];
77+
newRightMappings[i] = rightMappings[from];
78+
}
79+
return new DefaultBiJoiner<>(newLeftMappings, newJoinerTypes, newRightMappings);
80+
}
81+
5582
public Function<A, Object> getLeftMapping(int index) {
5683
return leftMappings[index];
5784
}
5885

5986
public boolean matches(A a, B b) {
60-
int joinerCount = getJoinerCount();
61-
for (int i = 0; i < joinerCount; i++) {
62-
JoinerType joinerType = getJoinerType(i);
63-
Object leftMapping = getLeftMapping(i).apply(a);
64-
Object rightMapping = getRightMapping(i).apply(b);
87+
var joinerCount = getJoinerCount();
88+
for (var i = 0; i < joinerCount; i++) {
89+
var joinerType = getJoinerType(i);
90+
var leftMapping = getLeftMapping(i).apply(a);
91+
var rightMapping = getRightMapping(i).apply(b);
6592
if (!joinerType.matches(leftMapping, rightMapping)) {
6693
return false;
6794
}

0 commit comments

Comments
 (0)