Skip to content

Commit dda0845

Browse files
committed
feat(neighborhoods): add groupBy() support
1 parent 510fc69 commit dda0845

29 files changed

Lines changed: 1447 additions & 197 deletions
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
package ai.timefold.solver.core.impl.neighborhood.stream.collector;
2+
3+
import java.util.List;
4+
import java.util.Objects;
5+
import java.util.function.Function;
6+
import java.util.function.Supplier;
7+
8+
import ai.timefold.solver.core.impl.score.stream.collector.AbstractToListSlot;
9+
import ai.timefold.solver.core.preview.api.move.SolutionView;
10+
import ai.timefold.solver.core.preview.api.neighborhood.stream.collector.BiNeighborhoodsCollector;
11+
import ai.timefold.solver.core.preview.api.neighborhood.stream.collector.BiNeighborhoodsCollectorAccumulator;
12+
import ai.timefold.solver.core.preview.api.neighborhood.stream.collector.BiNeighborhoodsCollectorValueHandle;
13+
import ai.timefold.solver.core.preview.api.neighborhood.stream.function.BiNeighborhoodsMapper;
14+
15+
import org.jspecify.annotations.NullMarked;
16+
import org.jspecify.annotations.Nullable;
17+
18+
@NullMarked
19+
public final class BiToListNeighborhoodsCollector<Solution_, A, B, Mapped_>
20+
implements BiNeighborhoodsCollector<Solution_, A, B, AbstractToListSlot.State<Mapped_>, List<Mapped_>> {
21+
22+
private final BiNeighborhoodsMapper<Solution_, A, B, Mapped_> mapper;
23+
24+
private BiToListNeighborhoodsCollector(BiNeighborhoodsMapper<Solution_, A, B, Mapped_> mapper) {
25+
this.mapper = Objects.requireNonNull(mapper);
26+
}
27+
28+
public static <Solution_, A, B, Mapped_> BiToListNeighborhoodsCollector<Solution_, A, B, Mapped_>
29+
create(BiNeighborhoodsMapper<Solution_, A, B, Mapped_> mapper) {
30+
return new BiToListNeighborhoodsCollector<>(mapper);
31+
}
32+
33+
@Override
34+
public Supplier<AbstractToListSlot.State<Mapped_>> supplier() {
35+
return AbstractToListSlot.State::new;
36+
}
37+
38+
@Override
39+
public BiNeighborhoodsCollectorAccumulator<Solution_, A, B, AbstractToListSlot.State<Mapped_>> accumulator() {
40+
return (view, state) -> new Slot<>(state, mapper, view);
41+
}
42+
43+
@Override
44+
public Function<AbstractToListSlot.State<Mapped_>, @Nullable List<Mapped_>> finisher() {
45+
return AbstractToListSlot.State::result;
46+
}
47+
48+
@Override
49+
public boolean equals(Object object) {
50+
if (this == object)
51+
return true;
52+
if (object == null || getClass() != object.getClass())
53+
return false;
54+
var that = (BiToListNeighborhoodsCollector<?, ?, ?, ?>) object;
55+
return Objects.equals(mapper, that.mapper);
56+
}
57+
58+
@Override
59+
public int hashCode() {
60+
return Objects.hash(mapper);
61+
}
62+
63+
private static final class Slot<Solution_, A, B, Mapped_> extends AbstractToListSlot<Mapped_>
64+
implements BiNeighborhoodsCollectorValueHandle<A, B> {
65+
66+
private final BiNeighborhoodsMapper<Solution_, A, B, Mapped_> mapper;
67+
private final SolutionView<Solution_> view;
68+
69+
Slot(AbstractToListSlot.State<Mapped_> state, BiNeighborhoodsMapper<Solution_, A, B, Mapped_> mapper,
70+
SolutionView<Solution_> view) {
71+
super(state);
72+
this.mapper = mapper;
73+
this.view = view;
74+
}
75+
76+
@Override
77+
public void add(@Nullable A a, @Nullable B b) {
78+
addMapped(mapper.apply(view, a, b));
79+
}
80+
81+
@Override
82+
public void replaceWith(@Nullable A a, @Nullable B b) {
83+
replaceWithMapped(mapper.apply(view, a, b));
84+
}
85+
86+
@Override
87+
public void remove() {
88+
removeMapped();
89+
}
90+
}
91+
}
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
package ai.timefold.solver.core.impl.neighborhood.stream.collector;
2+
3+
import java.util.function.BiFunction;
4+
import java.util.function.Function;
5+
import java.util.function.Supplier;
6+
7+
import ai.timefold.solver.core.api.function.TriFunction;
8+
import ai.timefold.solver.core.api.score.stream.bi.BiConstraintCollector;
9+
import ai.timefold.solver.core.api.score.stream.bi.BiConstraintCollectorAccumulator;
10+
import ai.timefold.solver.core.api.score.stream.bi.BiConstraintCollectorValueHandle;
11+
import ai.timefold.solver.core.api.score.stream.uni.UniConstraintCollector;
12+
import ai.timefold.solver.core.api.score.stream.uni.UniConstraintCollectorAccumulator;
13+
import ai.timefold.solver.core.api.score.stream.uni.UniConstraintCollectorValueHandle;
14+
import ai.timefold.solver.core.preview.api.move.SolutionView;
15+
import ai.timefold.solver.core.preview.api.neighborhood.stream.collector.BiNeighborhoodsCollector;
16+
import ai.timefold.solver.core.preview.api.neighborhood.stream.collector.BiNeighborhoodsCollectorValueHandle;
17+
import ai.timefold.solver.core.preview.api.neighborhood.stream.collector.UniNeighborhoodsCollector;
18+
import ai.timefold.solver.core.preview.api.neighborhood.stream.collector.UniNeighborhoodsCollectorValueHandle;
19+
20+
import org.jspecify.annotations.NullMarked;
21+
import org.jspecify.annotations.Nullable;
22+
23+
/**
24+
* Adapts neighborhoods collectors to Bavet's constraint collector API.
25+
* Called at {@code buildNode()} time once the {@link SolutionView} is available.
26+
*/
27+
@NullMarked
28+
public final class NeighborhoodsCollectorUtils {
29+
30+
public static <Solution_, A, ResultContainer_, Result_> UniConstraintCollector<A, ResultContainer_, Result_>
31+
toConstraintCollector(
32+
UniNeighborhoodsCollector<Solution_, A, ResultContainer_, Result_> collector,
33+
SolutionView<Solution_> view) {
34+
var acc = collector.accumulator();
35+
return new UniConstraintCollector<>() {
36+
@Override
37+
public Supplier<ResultContainer_> supplier() {
38+
return collector.supplier();
39+
}
40+
41+
@Override
42+
public BiFunction<ResultContainer_, A, Runnable> accumulator() {
43+
return (UniConstraintCollectorAccumulator<ResultContainer_, A>) container -> {
44+
var handle = acc.intoGroup(view, container);
45+
return wrapUni(handle);
46+
};
47+
}
48+
49+
@Override
50+
public Function<ResultContainer_, @Nullable Result_> finisher() {
51+
return collector.finisher();
52+
}
53+
};
54+
}
55+
56+
public static <Solution_, A, B, ResultContainer_, Result_> BiConstraintCollector<A, B, ResultContainer_, Result_>
57+
toConstraintCollector(
58+
BiNeighborhoodsCollector<Solution_, A, B, ResultContainer_, Result_> collector,
59+
SolutionView<Solution_> view) {
60+
var acc = collector.accumulator();
61+
return new BiConstraintCollector<>() {
62+
@Override
63+
public Supplier<ResultContainer_> supplier() {
64+
return collector.supplier();
65+
}
66+
67+
@Override
68+
public TriFunction<ResultContainer_, A, B, Runnable> accumulator() {
69+
return (BiConstraintCollectorAccumulator<ResultContainer_, A, B>) container -> {
70+
var handle = acc.intoGroup(view, container);
71+
return wrapBi(handle);
72+
};
73+
}
74+
75+
@Override
76+
public Function<ResultContainer_, @Nullable Result_> finisher() {
77+
return collector.finisher();
78+
}
79+
};
80+
}
81+
82+
private static <A> UniConstraintCollectorValueHandle<A> wrapUni(UniNeighborhoodsCollectorValueHandle<A> handle) {
83+
return new UniConstraintCollectorValueHandle<>() {
84+
@Override
85+
public void add(@Nullable A a) {
86+
handle.add(a);
87+
}
88+
89+
@Override
90+
public void replaceWith(@Nullable A a) {
91+
handle.replaceWith(a);
92+
}
93+
94+
@Override
95+
public void remove() {
96+
handle.remove();
97+
}
98+
};
99+
}
100+
101+
private static <A, B> BiConstraintCollectorValueHandle<A, B> wrapBi(BiNeighborhoodsCollectorValueHandle<A, B> handle) {
102+
return new BiConstraintCollectorValueHandle<>() {
103+
@Override
104+
public void add(@Nullable A a, @Nullable B b) {
105+
handle.add(a, b);
106+
}
107+
108+
@Override
109+
public void replaceWith(@Nullable A a, @Nullable B b) {
110+
handle.replaceWith(a, b);
111+
}
112+
113+
@Override
114+
public void remove() {
115+
handle.remove();
116+
}
117+
};
118+
}
119+
120+
private NeighborhoodsCollectorUtils() {
121+
}
122+
123+
}
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
package ai.timefold.solver.core.impl.neighborhood.stream.collector;
2+
3+
import java.util.List;
4+
import java.util.function.Function;
5+
import java.util.function.Supplier;
6+
7+
import ai.timefold.solver.core.impl.score.stream.collector.AbstractToListSlot;
8+
import ai.timefold.solver.core.preview.api.neighborhood.stream.collector.UniNeighborhoodsCollector;
9+
import ai.timefold.solver.core.preview.api.neighborhood.stream.collector.UniNeighborhoodsCollectorAccumulator;
10+
import ai.timefold.solver.core.preview.api.neighborhood.stream.collector.UniNeighborhoodsCollectorValueHandle;
11+
12+
import org.jspecify.annotations.NullMarked;
13+
import org.jspecify.annotations.Nullable;
14+
15+
@NullMarked
16+
public final class UniToListNeighborhoodsCollector<Solution_, A>
17+
implements UniNeighborhoodsCollector<Solution_, A, AbstractToListSlot.State<A>, List<A>> {
18+
19+
private static final UniToListNeighborhoodsCollector<?, ?> INSTANCE = new UniToListNeighborhoodsCollector<>();
20+
21+
@SuppressWarnings("unchecked")
22+
public static <Solution_, A> UniToListNeighborhoodsCollector<Solution_, A> create() {
23+
return (UniToListNeighborhoodsCollector<Solution_, A>) INSTANCE;
24+
}
25+
26+
private UniToListNeighborhoodsCollector() {
27+
}
28+
29+
@Override
30+
public Supplier<AbstractToListSlot.State<A>> supplier() {
31+
return AbstractToListSlot.State::new;
32+
}
33+
34+
@Override
35+
public UniNeighborhoodsCollectorAccumulator<Solution_, A, AbstractToListSlot.State<A>> accumulator() {
36+
return (view, state) -> new Slot<>(state);
37+
}
38+
39+
@Override
40+
public Function<AbstractToListSlot.State<A>, @Nullable List<A>> finisher() {
41+
return AbstractToListSlot.State::result;
42+
}
43+
44+
private static final class Slot<A> extends AbstractToListSlot<A>
45+
implements UniNeighborhoodsCollectorValueHandle<A> {
46+
47+
Slot(AbstractToListSlot.State<A> state) {
48+
super(state);
49+
}
50+
51+
@Override
52+
public void add(@Nullable A a) {
53+
addMapped(a);
54+
}
55+
56+
@Override
57+
public void replaceWith(@Nullable A a) {
58+
replaceWithMapped(a);
59+
}
60+
61+
@Override
62+
public void remove() {
63+
removeMapped();
64+
}
65+
}
66+
}

core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/bi/AbstractBiEnumeratingStream.java

Lines changed: 38 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
package ai.timefold.solver.core.impl.neighborhood.stream.enumerating.bi;
22

3-
import java.util.function.BiFunction;
4-
5-
import ai.timefold.solver.core.impl.bavet.bi.Group2Mapping0CollectorBiNode;
6-
import ai.timefold.solver.core.impl.bavet.common.GroupNodeConstructor;
73
import ai.timefold.solver.core.impl.bavet.common.tuple.BiTuple;
4+
import ai.timefold.solver.core.impl.bavet.common.tuple.UniTuple;
85
import ai.timefold.solver.core.impl.neighborhood.stream.enumerating.EnumeratingStreamFactory;
96
import ai.timefold.solver.core.impl.neighborhood.stream.enumerating.common.AbstractEnumeratingStream;
7+
import ai.timefold.solver.core.impl.neighborhood.stream.enumerating.common.NeighborhoodsGroupNodeConstructor;
108
import ai.timefold.solver.core.impl.neighborhood.stream.enumerating.common.bridge.AftBridgeBiEnumeratingStream;
119
import ai.timefold.solver.core.impl.neighborhood.stream.enumerating.common.bridge.AftBridgeUniEnumeratingStream;
10+
import ai.timefold.solver.core.impl.neighborhood.stream.enumerating.uni.AbstractUniEnumeratingStream;
1211
import ai.timefold.solver.core.impl.util.ConstantLambdaUtils;
12+
import ai.timefold.solver.core.preview.api.neighborhood.stream.collector.BiNeighborhoodsCollector;
1313
import ai.timefold.solver.core.preview.api.neighborhood.stream.enumerating.BiEnumeratingStream;
1414
import ai.timefold.solver.core.preview.api.neighborhood.stream.enumerating.UniEnumeratingStream;
1515
import ai.timefold.solver.core.preview.api.neighborhood.stream.function.BiNeighborhoodsMapper;
@@ -36,15 +36,41 @@ public final BiEnumeratingStream<Solution_, A, B> filter(BiNeighborhoodsPredicat
3636
return shareAndAddChild(new FilterBiEnumeratingStream<>(enumeratingStreamFactory, this, filter));
3737
}
3838

39-
protected <GroupKeyA_, GroupKeyB_> AbstractBiEnumeratingStream<Solution_, GroupKeyA_, GroupKeyB_>
40-
groupBy(BiFunction<A, B, GroupKeyA_> groupKeyAMapping, BiFunction<A, B, GroupKeyB_> groupKeyBMapping) {
41-
GroupNodeConstructor<BiTuple<GroupKeyA_, GroupKeyB_>> nodeConstructor =
42-
GroupNodeConstructor.twoKeysGroupBy(groupKeyAMapping, groupKeyBMapping, Group2Mapping0CollectorBiNode::new);
43-
return buildBiGroupBy(nodeConstructor);
39+
@Override
40+
public <GroupKey_> AbstractUniEnumeratingStream<Solution_, GroupKey_> groupBy(
41+
BiNeighborhoodsMapper<Solution_, A, B, GroupKey_> key) {
42+
return buildUniGroupBy(NeighborhoodsGroupNodeConstructor.biOneKeyGroupBy(key));
43+
}
44+
45+
@Override
46+
public <Result_> AbstractUniEnumeratingStream<Solution_, Result_> groupBy(
47+
BiNeighborhoodsCollector<Solution_, A, B, ?, Result_> collector) {
48+
return buildUniGroupBy(NeighborhoodsGroupNodeConstructor.biZeroKeysGroupBy(collector));
49+
}
50+
51+
@Override
52+
public <GroupKeyA_, GroupKeyB_> AbstractBiEnumeratingStream<Solution_, GroupKeyA_, GroupKeyB_> groupBy(
53+
BiNeighborhoodsMapper<Solution_, A, B, GroupKeyA_> keyA,
54+
BiNeighborhoodsMapper<Solution_, A, B, GroupKeyB_> keyB) {
55+
return buildBiGroupBy(NeighborhoodsGroupNodeConstructor.biTwoKeysGroupBy(keyA, keyB));
56+
}
57+
58+
@Override
59+
public <GroupKey_, Result_> AbstractBiEnumeratingStream<Solution_, GroupKey_, Result_> groupBy(
60+
BiNeighborhoodsMapper<Solution_, A, B, GroupKey_> key,
61+
BiNeighborhoodsCollector<Solution_, A, B, ?, Result_> collector) {
62+
return buildBiGroupBy(NeighborhoodsGroupNodeConstructor.biOneKeyAndCollectorGroupBy(key, collector));
63+
}
64+
65+
private <NewA> AbstractUniEnumeratingStream<Solution_, NewA> buildUniGroupBy(
66+
NeighborhoodsGroupNodeConstructor<Solution_, UniTuple<NewA>> nodeConstructor) {
67+
var stream = shareAndAddChild(new BiGroupUniEnumeratingStream<>(enumeratingStreamFactory, this, nodeConstructor));
68+
return enumeratingStreamFactory.share(new AftBridgeUniEnumeratingStream<>(enumeratingStreamFactory, stream),
69+
stream::setAftBridge);
4470
}
4571

46-
private <NewA, NewB> AbstractBiEnumeratingStream<Solution_, NewA, NewB>
47-
buildBiGroupBy(GroupNodeConstructor<BiTuple<NewA, NewB>> nodeConstructor) {
72+
private <NewA, NewB> AbstractBiEnumeratingStream<Solution_, NewA, NewB> buildBiGroupBy(
73+
NeighborhoodsGroupNodeConstructor<Solution_, BiTuple<NewA, NewB>> nodeConstructor) {
4874
var stream = shareAndAddChild(new BiGroupBiEnumeratingStream<>(enumeratingStreamFactory, this, nodeConstructor));
4975
return enumeratingStreamFactory.share(new AftBridgeBiEnumeratingStream<>(enumeratingStreamFactory, stream),
5076
stream::setAftBridge);
@@ -71,7 +97,7 @@ public AbstractBiEnumeratingStream<Solution_, A, B> distinct() {
7197
if (guaranteesDistinct()) {
7298
return this; // Already distinct, no need to create a new stream.
7399
}
74-
return groupBy(ConstantLambdaUtils.biPickFirst(), ConstantLambdaUtils.biPickSecond());
100+
return groupBy(ConstantLambdaUtils.neighborhoodsBiPickFirst(), ConstantLambdaUtils.neighborhoodsBiPickSecond());
75101
}
76102

77103
}

0 commit comments

Comments
 (0)