diff --git a/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/bi/AbstractBiEnumeratingStream.java b/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/bi/AbstractBiEnumeratingStream.java index 62984633f60..2f2f2cf0705 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/bi/AbstractBiEnumeratingStream.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/bi/AbstractBiEnumeratingStream.java @@ -1,17 +1,17 @@ package ai.timefold.solver.core.impl.neighborhood.stream.enumerating.bi; -import java.util.function.BiFunction; - -import ai.timefold.solver.core.impl.bavet.bi.Group2Mapping0CollectorBiNode; -import ai.timefold.solver.core.impl.bavet.common.GroupNodeConstructor; import ai.timefold.solver.core.impl.bavet.common.tuple.BiTuple; +import ai.timefold.solver.core.impl.bavet.common.tuple.UniTuple; import ai.timefold.solver.core.impl.neighborhood.stream.enumerating.EnumeratingStreamFactory; import ai.timefold.solver.core.impl.neighborhood.stream.enumerating.common.AbstractEnumeratingStream; +import ai.timefold.solver.core.impl.neighborhood.stream.enumerating.common.NeighborhoodsGroupNodeConstructor; import ai.timefold.solver.core.impl.neighborhood.stream.enumerating.common.bridge.AftBridgeBiEnumeratingStream; import ai.timefold.solver.core.impl.neighborhood.stream.enumerating.common.bridge.AftBridgeUniEnumeratingStream; +import ai.timefold.solver.core.impl.neighborhood.stream.enumerating.uni.AbstractUniEnumeratingStream; import ai.timefold.solver.core.impl.util.ConstantLambdaUtils; import ai.timefold.solver.core.preview.api.neighborhood.stream.enumerating.BiEnumeratingStream; import ai.timefold.solver.core.preview.api.neighborhood.stream.enumerating.UniEnumeratingStream; +import ai.timefold.solver.core.preview.api.neighborhood.stream.enumerating.collector.BiNeighborhoodsCollector; import ai.timefold.solver.core.preview.api.neighborhood.stream.function.BiNeighborhoodsMapper; import ai.timefold.solver.core.preview.api.neighborhood.stream.function.BiNeighborhoodsPredicate; @@ -36,15 +36,40 @@ public final BiEnumeratingStream filter(BiNeighborhoodsPredicat return shareAndAddChild(new FilterBiEnumeratingStream<>(enumeratingStreamFactory, this, filter)); } - protected AbstractBiEnumeratingStream - groupBy(BiFunction groupKeyAMapping, BiFunction groupKeyBMapping) { - GroupNodeConstructor> nodeConstructor = - GroupNodeConstructor.twoKeysGroupBy(groupKeyAMapping, groupKeyBMapping, Group2Mapping0CollectorBiNode::new); - return buildBiGroupBy(nodeConstructor); + @Override + public AbstractUniEnumeratingStream groupBy( + BiNeighborhoodsMapper key) { + return buildUniGroupBy(NeighborhoodsGroupNodeConstructor.biOneKeyGroupBy(key)); + } + + @Override + public AbstractUniEnumeratingStream groupBy( + BiNeighborhoodsCollector collector) { + return buildUniGroupBy(NeighborhoodsGroupNodeConstructor.biZeroKeysGroupBy(collector)); + } + + private AbstractBiEnumeratingStream groupBy( + BiNeighborhoodsMapper keyA, + BiNeighborhoodsMapper keyB) { + return buildBiGroupBy(NeighborhoodsGroupNodeConstructor.biTwoKeysGroupBy(keyA, keyB)); + } + + @Override + public AbstractBiEnumeratingStream groupBy( + BiNeighborhoodsMapper key, + BiNeighborhoodsCollector collector) { + return buildBiGroupBy(NeighborhoodsGroupNodeConstructor.biOneKeyAndCollectorGroupBy(key, collector)); + } + + private AbstractUniEnumeratingStream buildUniGroupBy( + NeighborhoodsGroupNodeConstructor> nodeConstructor) { + var stream = shareAndAddChild(new BiGroupUniEnumeratingStream<>(enumeratingStreamFactory, this, nodeConstructor)); + return enumeratingStreamFactory.share(new AftBridgeUniEnumeratingStream<>(enumeratingStreamFactory, stream), + stream::setAftBridge); } - private AbstractBiEnumeratingStream - buildBiGroupBy(GroupNodeConstructor> nodeConstructor) { + private AbstractBiEnumeratingStream buildBiGroupBy( + NeighborhoodsGroupNodeConstructor> nodeConstructor) { var stream = shareAndAddChild(new BiGroupBiEnumeratingStream<>(enumeratingStreamFactory, this, nodeConstructor)); return enumeratingStreamFactory.share(new AftBridgeBiEnumeratingStream<>(enumeratingStreamFactory, stream), stream::setAftBridge); @@ -71,7 +96,7 @@ public AbstractBiEnumeratingStream distinct() { if (guaranteesDistinct()) { return this; // Already distinct, no need to create a new stream. } - return groupBy(ConstantLambdaUtils.biPickFirst(), ConstantLambdaUtils.biPickSecond()); + return groupBy(ConstantLambdaUtils.neighborhoodsBiPickFirst(), ConstantLambdaUtils.neighborhoodsBiPickSecond()); } } diff --git a/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/bi/BiGroupBiEnumeratingStream.java b/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/bi/BiGroupBiEnumeratingStream.java index 11d61291e6b..576fa437210 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/bi/BiGroupBiEnumeratingStream.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/bi/BiGroupBiEnumeratingStream.java @@ -2,10 +2,10 @@ import java.util.Objects; -import ai.timefold.solver.core.impl.bavet.common.GroupNodeConstructor; import ai.timefold.solver.core.impl.bavet.common.tuple.BiTuple; import ai.timefold.solver.core.impl.neighborhood.stream.enumerating.EnumeratingStreamFactory; import ai.timefold.solver.core.impl.neighborhood.stream.enumerating.common.DataNodeBuildHelper; +import ai.timefold.solver.core.impl.neighborhood.stream.enumerating.common.NeighborhoodsGroupNodeConstructor; import ai.timefold.solver.core.impl.neighborhood.stream.enumerating.common.bridge.AftBridgeBiEnumeratingStream; import org.jspecify.annotations.NullMarked; @@ -15,25 +15,30 @@ final class BiGroupBiEnumeratingStream extends AbstractBiEnumeratingStream { - private final GroupNodeConstructor> nodeConstructor; + private final NeighborhoodsGroupNodeConstructor> nodeConstructor; private @Nullable AftBridgeBiEnumeratingStream aftStream; - public BiGroupBiEnumeratingStream(EnumeratingStreamFactory enumeratingStreamFactory, + BiGroupBiEnumeratingStream(EnumeratingStreamFactory enumeratingStreamFactory, AbstractBiEnumeratingStream parent, - GroupNodeConstructor> nodeConstructor) { + NeighborhoodsGroupNodeConstructor> nodeConstructor) { super(enumeratingStreamFactory, parent); - this.nodeConstructor = nodeConstructor; + this.nodeConstructor = Objects.requireNonNull(nodeConstructor); } - public void setAftBridge(AftBridgeBiEnumeratingStream aftStream) { + void setAftBridge(AftBridgeBiEnumeratingStream aftStream) { this.aftStream = aftStream; } + @Override + public boolean guaranteesDistinct() { + return true; + } + @Override public void buildNode(DataNodeBuildHelper buildHelper) { - var aftStreamChildList = aftStream.getChildStreamList(); - nodeConstructor.build(buildHelper, parent.getTupleSource(), aftStream, aftStreamChildList, this, - enumeratingStreamFactory.getEnvironmentMode()); + var view = buildHelper.getSessionContext().solutionView(); + nodeConstructor.build(buildHelper, parent.getTupleSource(), aftStream, + aftStream.getChildStreamList(), this, enumeratingStreamFactory.getEnvironmentMode(), view); } @Override @@ -53,7 +58,6 @@ public int hashCode() { @Override public String toString() { - return "BiGroup()"; + return "BiGroupBi()"; } - } diff --git a/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/bi/BiGroupUniEnumeratingStream.java b/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/bi/BiGroupUniEnumeratingStream.java new file mode 100644 index 00000000000..120a2065e8d --- /dev/null +++ b/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/bi/BiGroupUniEnumeratingStream.java @@ -0,0 +1,63 @@ +package ai.timefold.solver.core.impl.neighborhood.stream.enumerating.bi; + +import java.util.Objects; + +import ai.timefold.solver.core.impl.bavet.common.tuple.UniTuple; +import ai.timefold.solver.core.impl.neighborhood.stream.enumerating.EnumeratingStreamFactory; +import ai.timefold.solver.core.impl.neighborhood.stream.enumerating.common.DataNodeBuildHelper; +import ai.timefold.solver.core.impl.neighborhood.stream.enumerating.common.NeighborhoodsGroupNodeConstructor; +import ai.timefold.solver.core.impl.neighborhood.stream.enumerating.common.bridge.AftBridgeUniEnumeratingStream; + +import org.jspecify.annotations.NullMarked; +import org.jspecify.annotations.Nullable; + +@NullMarked +final class BiGroupUniEnumeratingStream + extends AbstractBiEnumeratingStream { + + private final NeighborhoodsGroupNodeConstructor> nodeConstructor; + private @Nullable AftBridgeUniEnumeratingStream aftStream; + + BiGroupUniEnumeratingStream(EnumeratingStreamFactory enumeratingStreamFactory, + AbstractBiEnumeratingStream parent, + NeighborhoodsGroupNodeConstructor> nodeConstructor) { + super(enumeratingStreamFactory, parent); + this.nodeConstructor = Objects.requireNonNull(nodeConstructor); + } + + void setAftBridge(AftBridgeUniEnumeratingStream aftStream) { + this.aftStream = aftStream; + } + + @Override + public boolean guaranteesDistinct() { + return true; + } + + @Override + public void buildNode(DataNodeBuildHelper buildHelper) { + var view = buildHelper.getSessionContext().solutionView(); + nodeConstructor.build(buildHelper, parent.getTupleSource(), aftStream, + aftStream.getChildStreamList(), this, enumeratingStreamFactory.getEnvironmentMode(), view); + } + + @Override + public boolean equals(Object object) { + if (this == object) + return true; + if (object == null || getClass() != object.getClass()) + return false; + var that = (BiGroupUniEnumeratingStream) object; + return Objects.equals(parent, that.parent) && Objects.equals(nodeConstructor, that.nodeConstructor); + } + + @Override + public int hashCode() { + return Objects.hash(parent, nodeConstructor); + } + + @Override + public String toString() { + return "BiGroupUni()"; + } +} diff --git a/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/collector/AndThenBiNeighborhoodsCollector.java b/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/collector/AndThenBiNeighborhoodsCollector.java new file mode 100644 index 00000000000..fcf3c315047 --- /dev/null +++ b/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/collector/AndThenBiNeighborhoodsCollector.java @@ -0,0 +1,58 @@ +package ai.timefold.solver.core.impl.neighborhood.stream.enumerating.collector; + +import java.util.Objects; +import java.util.function.Function; +import java.util.function.Supplier; + +import ai.timefold.solver.core.preview.api.neighborhood.stream.enumerating.collector.BiNeighborhoodsCollector; +import ai.timefold.solver.core.preview.api.neighborhood.stream.enumerating.collector.BiNeighborhoodsCollectorAccumulator; + +import org.jspecify.annotations.NullMarked; +import org.jspecify.annotations.Nullable; + +@NullMarked +public final class AndThenBiNeighborhoodsCollector + implements BiNeighborhoodsCollector { + + private final BiNeighborhoodsCollector delegate; + private final Function<@Nullable Intermediate_, @Nullable Result_> mappingFunction; + + public AndThenBiNeighborhoodsCollector( + BiNeighborhoodsCollector delegate, + Function mappingFunction) { + this.delegate = Objects.requireNonNull(delegate); + this.mappingFunction = Objects.requireNonNull(mappingFunction); + } + + @Override + public Supplier supplier() { + return delegate.supplier(); + } + + @Override + public BiNeighborhoodsCollectorAccumulator accumulator() { + return delegate.accumulator(); + } + + @Override + public Function finisher() { + var finisher = delegate.finisher(); + return container -> mappingFunction.apply(finisher.apply(container)); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + return o instanceof AndThenBiNeighborhoodsCollector other + && Objects.equals(delegate, other.delegate) + && Objects.equals(mappingFunction, other.mappingFunction); + } + + @Override + public int hashCode() { + return Objects.hash(delegate, mappingFunction); + } + +} diff --git a/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/collector/AndThenUniNeighborhoodsCollector.java b/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/collector/AndThenUniNeighborhoodsCollector.java new file mode 100644 index 00000000000..59305ec4eca --- /dev/null +++ b/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/collector/AndThenUniNeighborhoodsCollector.java @@ -0,0 +1,58 @@ +package ai.timefold.solver.core.impl.neighborhood.stream.enumerating.collector; + +import java.util.Objects; +import java.util.function.Function; +import java.util.function.Supplier; + +import ai.timefold.solver.core.preview.api.neighborhood.stream.enumerating.collector.UniNeighborhoodsCollector; +import ai.timefold.solver.core.preview.api.neighborhood.stream.enumerating.collector.UniNeighborhoodsCollectorAccumulator; + +import org.jspecify.annotations.NullMarked; +import org.jspecify.annotations.Nullable; + +@NullMarked +public final class AndThenUniNeighborhoodsCollector + implements UniNeighborhoodsCollector { + + private final UniNeighborhoodsCollector delegate; + private final Function<@Nullable Intermediate_, @Nullable Result_> mappingFunction; + + public AndThenUniNeighborhoodsCollector( + UniNeighborhoodsCollector delegate, + Function mappingFunction) { + this.delegate = Objects.requireNonNull(delegate); + this.mappingFunction = Objects.requireNonNull(mappingFunction); + } + + @Override + public Supplier supplier() { + return delegate.supplier(); + } + + @Override + public UniNeighborhoodsCollectorAccumulator accumulator() { + return delegate.accumulator(); + } + + @Override + public Function finisher() { + var finisher = delegate.finisher(); + return container -> mappingFunction.apply(finisher.apply(container)); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + return o instanceof AndThenUniNeighborhoodsCollector other + && Objects.equals(delegate, other.delegate) + && Objects.equals(mappingFunction, other.mappingFunction); + } + + @Override + public int hashCode() { + return Objects.hash(delegate, mappingFunction); + } + +} diff --git a/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/collector/ComposeTwoBiNeighborhoodsCollector.java b/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/collector/ComposeTwoBiNeighborhoodsCollector.java new file mode 100644 index 00000000000..33c3cddac77 --- /dev/null +++ b/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/collector/ComposeTwoBiNeighborhoodsCollector.java @@ -0,0 +1,106 @@ +package ai.timefold.solver.core.impl.neighborhood.stream.enumerating.collector; + +import java.util.Objects; +import java.util.function.BiFunction; +import java.util.function.Function; +import java.util.function.Supplier; + +import ai.timefold.solver.core.impl.util.Pair; +import ai.timefold.solver.core.preview.api.move.SolutionView; +import ai.timefold.solver.core.preview.api.neighborhood.stream.enumerating.collector.BiNeighborhoodsCollector; +import ai.timefold.solver.core.preview.api.neighborhood.stream.enumerating.collector.BiNeighborhoodsCollectorAccumulator; +import ai.timefold.solver.core.preview.api.neighborhood.stream.enumerating.collector.BiNeighborhoodsCollectorValueHandle; + +import org.jspecify.annotations.NullMarked; +import org.jspecify.annotations.Nullable; + +@NullMarked +public final class ComposeTwoBiNeighborhoodsCollector + implements BiNeighborhoodsCollector, Result_> { + + private final BiNeighborhoodsCollector first; + private final BiNeighborhoodsCollector second; + private final BiFunction<@Nullable Result1_, @Nullable Result2_, @Nullable Result_> composeFunction; + + private final Supplier firstSupplier; + private final Supplier secondSupplier; + private final BiNeighborhoodsCollectorAccumulator firstAccumulator; + private final BiNeighborhoodsCollectorAccumulator secondAccumulator; + private final Function firstFinisher; + private final Function secondFinisher; + + public ComposeTwoBiNeighborhoodsCollector( + BiNeighborhoodsCollector first, + BiNeighborhoodsCollector second, + BiFunction composeFunction) { + this.first = first; + this.second = second; + this.composeFunction = composeFunction; + this.firstSupplier = first.supplier(); + this.secondSupplier = second.supplier(); + this.firstAccumulator = first.accumulator(); + this.secondAccumulator = second.accumulator(); + this.firstFinisher = first.finisher(); + this.secondFinisher = second.finisher(); + } + + @Override + public Supplier> supplier() { + return () -> new Pair<>(firstSupplier.get(), secondSupplier.get()); + } + + @Override + public BiNeighborhoodsCollectorAccumulator> accumulator() { + return ValueHandle::new; + } + + @Override + public Function, @Nullable Result_> finisher() { + return pair -> composeFunction.apply(firstFinisher.apply(pair.key()), secondFinisher.apply(pair.value())); + } + + @Override + public boolean equals(Object object) { + if (this == object) { + return true; + } + return object instanceof ComposeTwoBiNeighborhoodsCollector other + && Objects.equals(first, other.first) + && Objects.equals(second, other.second) + && Objects.equals(composeFunction, other.composeFunction); + } + + @Override + public int hashCode() { + return Objects.hash(first, second, composeFunction); + } + + private final class ValueHandle implements BiNeighborhoodsCollectorValueHandle { + + private final BiNeighborhoodsCollectorValueHandle v1; + private final BiNeighborhoodsCollectorValueHandle v2; + + ValueHandle(SolutionView view, Pair container) { + this.v1 = firstAccumulator.intoGroup(view, container.key()); + this.v2 = secondAccumulator.intoGroup(view, container.value()); + } + + @Override + public void add(@Nullable A a, @Nullable B b) { + v1.add(a, b); + v2.add(a, b); + } + + @Override + public void replaceWith(@Nullable A a, @Nullable B b) { + v1.replaceWith(a, b); + v2.replaceWith(a, b); + } + + @Override + public void remove() { + v1.remove(); + v2.remove(); + } + } +} diff --git a/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/collector/ComposeTwoUniNeighborhoodsCollector.java b/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/collector/ComposeTwoUniNeighborhoodsCollector.java new file mode 100644 index 00000000000..641b5617774 --- /dev/null +++ b/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/collector/ComposeTwoUniNeighborhoodsCollector.java @@ -0,0 +1,106 @@ +package ai.timefold.solver.core.impl.neighborhood.stream.enumerating.collector; + +import java.util.Objects; +import java.util.function.BiFunction; +import java.util.function.Function; +import java.util.function.Supplier; + +import ai.timefold.solver.core.impl.util.Pair; +import ai.timefold.solver.core.preview.api.move.SolutionView; +import ai.timefold.solver.core.preview.api.neighborhood.stream.enumerating.collector.UniNeighborhoodsCollector; +import ai.timefold.solver.core.preview.api.neighborhood.stream.enumerating.collector.UniNeighborhoodsCollectorAccumulator; +import ai.timefold.solver.core.preview.api.neighborhood.stream.enumerating.collector.UniNeighborhoodsCollectorValueHandle; + +import org.jspecify.annotations.NullMarked; +import org.jspecify.annotations.Nullable; + +@NullMarked +public final class ComposeTwoUniNeighborhoodsCollector + implements UniNeighborhoodsCollector, Result_> { + + private final UniNeighborhoodsCollector first; + private final UniNeighborhoodsCollector second; + private final BiFunction<@Nullable Result1_, @Nullable Result2_, @Nullable Result_> composeFunction; + + private final Supplier firstSupplier; + private final Supplier secondSupplier; + private final UniNeighborhoodsCollectorAccumulator firstAccumulator; + private final UniNeighborhoodsCollectorAccumulator secondAccumulator; + private final Function firstFinisher; + private final Function secondFinisher; + + public ComposeTwoUniNeighborhoodsCollector( + UniNeighborhoodsCollector first, + UniNeighborhoodsCollector second, + BiFunction composeFunction) { + this.first = first; + this.second = second; + this.composeFunction = composeFunction; + this.firstSupplier = first.supplier(); + this.secondSupplier = second.supplier(); + this.firstAccumulator = first.accumulator(); + this.secondAccumulator = second.accumulator(); + this.firstFinisher = first.finisher(); + this.secondFinisher = second.finisher(); + } + + @Override + public Supplier> supplier() { + return () -> new Pair<>(firstSupplier.get(), secondSupplier.get()); + } + + @Override + public UniNeighborhoodsCollectorAccumulator> accumulator() { + return ValueHandle::new; + } + + @Override + public Function, @Nullable Result_> finisher() { + return pair -> composeFunction.apply(firstFinisher.apply(pair.key()), secondFinisher.apply(pair.value())); + } + + @Override + public boolean equals(Object object) { + if (this == object) { + return true; + } + return object instanceof ComposeTwoUniNeighborhoodsCollector other + && Objects.equals(first, other.first) + && Objects.equals(second, other.second) + && Objects.equals(composeFunction, other.composeFunction); + } + + @Override + public int hashCode() { + return Objects.hash(first, second, composeFunction); + } + + private final class ValueHandle implements UniNeighborhoodsCollectorValueHandle { + + private final UniNeighborhoodsCollectorValueHandle v1; + private final UniNeighborhoodsCollectorValueHandle v2; + + ValueHandle(SolutionView view, Pair container) { + this.v1 = firstAccumulator.intoGroup(view, container.key()); + this.v2 = secondAccumulator.intoGroup(view, container.value()); + } + + @Override + public void add(@Nullable A a) { + v1.add(a); + v2.add(a); + } + + @Override + public void replaceWith(@Nullable A a) { + v1.replaceWith(a); + v2.replaceWith(a); + } + + @Override + public void remove() { + v1.remove(); + v2.remove(); + } + } +} diff --git a/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/collector/NeighborhoodsCollectorUtils.java b/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/collector/NeighborhoodsCollectorUtils.java new file mode 100644 index 00000000000..5c9104987b7 --- /dev/null +++ b/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/collector/NeighborhoodsCollectorUtils.java @@ -0,0 +1,123 @@ +package ai.timefold.solver.core.impl.neighborhood.stream.enumerating.collector; + +import java.util.function.BiFunction; +import java.util.function.Function; +import java.util.function.Supplier; + +import ai.timefold.solver.core.api.function.TriFunction; +import ai.timefold.solver.core.api.score.stream.bi.BiConstraintCollector; +import ai.timefold.solver.core.api.score.stream.bi.BiConstraintCollectorAccumulator; +import ai.timefold.solver.core.api.score.stream.bi.BiConstraintCollectorValueHandle; +import ai.timefold.solver.core.api.score.stream.uni.UniConstraintCollector; +import ai.timefold.solver.core.api.score.stream.uni.UniConstraintCollectorAccumulator; +import ai.timefold.solver.core.api.score.stream.uni.UniConstraintCollectorValueHandle; +import ai.timefold.solver.core.preview.api.move.SolutionView; +import ai.timefold.solver.core.preview.api.neighborhood.stream.enumerating.collector.BiNeighborhoodsCollector; +import ai.timefold.solver.core.preview.api.neighborhood.stream.enumerating.collector.BiNeighborhoodsCollectorValueHandle; +import ai.timefold.solver.core.preview.api.neighborhood.stream.enumerating.collector.UniNeighborhoodsCollector; +import ai.timefold.solver.core.preview.api.neighborhood.stream.enumerating.collector.UniNeighborhoodsCollectorValueHandle; + +import org.jspecify.annotations.NullMarked; +import org.jspecify.annotations.Nullable; + +/** + * Adapts neighborhoods collectors to Bavet's constraint collector API. + * Called at {@code buildNode()} time once the {@link SolutionView} is available. + */ +@NullMarked +public final class NeighborhoodsCollectorUtils { + + public static UniConstraintCollector + toConstraintCollector( + UniNeighborhoodsCollector collector, + SolutionView view) { + var acc = collector.accumulator(); + return new UniConstraintCollector<>() { + @Override + public Supplier supplier() { + return collector.supplier(); + } + + @Override + public BiFunction accumulator() { + return (UniConstraintCollectorAccumulator) container -> { + var handle = acc.intoGroup(view, container); + return wrapUni(handle); + }; + } + + @Override + public Function finisher() { + return collector.finisher(); + } + }; + } + + public static BiConstraintCollector + toConstraintCollector( + BiNeighborhoodsCollector collector, + SolutionView view) { + var acc = collector.accumulator(); + return new BiConstraintCollector<>() { + @Override + public Supplier supplier() { + return collector.supplier(); + } + + @Override + public TriFunction accumulator() { + return (BiConstraintCollectorAccumulator) container -> { + var handle = acc.intoGroup(view, container); + return wrapBi(handle); + }; + } + + @Override + public Function finisher() { + return collector.finisher(); + } + }; + } + + private static UniConstraintCollectorValueHandle wrapUni(UniNeighborhoodsCollectorValueHandle handle) { + return new UniConstraintCollectorValueHandle<>() { + @Override + public void add(@Nullable A a) { + handle.add(a); + } + + @Override + public void replaceWith(@Nullable A a) { + handle.replaceWith(a); + } + + @Override + public void remove() { + handle.remove(); + } + }; + } + + private static BiConstraintCollectorValueHandle wrapBi(BiNeighborhoodsCollectorValueHandle handle) { + return new BiConstraintCollectorValueHandle<>() { + @Override + public void add(@Nullable A a, @Nullable B b) { + handle.add(a, b); + } + + @Override + public void replaceWith(@Nullable A a, @Nullable B b) { + handle.replaceWith(a, b); + } + + @Override + public void remove() { + handle.remove(); + } + }; + } + + private NeighborhoodsCollectorUtils() { + } + +} diff --git a/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/collector/ToListBiNeighborhoodsCollector.java b/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/collector/ToListBiNeighborhoodsCollector.java new file mode 100644 index 00000000000..d75257a89e6 --- /dev/null +++ b/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/collector/ToListBiNeighborhoodsCollector.java @@ -0,0 +1,88 @@ +package ai.timefold.solver.core.impl.neighborhood.stream.enumerating.collector; + +import java.util.List; +import java.util.Objects; +import java.util.function.Function; +import java.util.function.Supplier; + +import ai.timefold.solver.core.impl.score.stream.collector.AbstractToListSlot; +import ai.timefold.solver.core.preview.api.move.SolutionView; +import ai.timefold.solver.core.preview.api.neighborhood.stream.enumerating.collector.BiNeighborhoodsCollector; +import ai.timefold.solver.core.preview.api.neighborhood.stream.enumerating.collector.BiNeighborhoodsCollectorAccumulator; +import ai.timefold.solver.core.preview.api.neighborhood.stream.enumerating.collector.BiNeighborhoodsCollectorValueHandle; +import ai.timefold.solver.core.preview.api.neighborhood.stream.function.BiNeighborhoodsMapper; + +import org.jspecify.annotations.NullMarked; +import org.jspecify.annotations.Nullable; + +@NullMarked +public final class ToListBiNeighborhoodsCollector + implements BiNeighborhoodsCollector, List> { + + private final BiNeighborhoodsMapper mapper; + + private ToListBiNeighborhoodsCollector(BiNeighborhoodsMapper mapper) { + this.mapper = Objects.requireNonNull(mapper); + } + + public static ToListBiNeighborhoodsCollector + create(BiNeighborhoodsMapper mapper) { + return new ToListBiNeighborhoodsCollector<>(mapper); + } + + @Override + public Supplier> supplier() { + return AbstractToListSlot.State::new; + } + + @Override + public BiNeighborhoodsCollectorAccumulator> accumulator() { + return (view, state) -> new Slot(state, view); + } + + @Override + public Function, @Nullable List> finisher() { + return AbstractToListSlot.State::result; + } + + @Override + public boolean equals(Object object) { + if (this == object) { + return true; + } + return object instanceof ToListBiNeighborhoodsCollector other + && Objects.equals(mapper, other.mapper); + } + + @Override + public int hashCode() { + return Objects.hash(mapper); + } + + private final class Slot + extends AbstractToListSlot + implements BiNeighborhoodsCollectorValueHandle { + + private final SolutionView view; + + Slot(AbstractToListSlot.State state, SolutionView view) { + super(state); + this.view = view; + } + + @Override + public void add(@Nullable A a, @Nullable B b) { + addMapped(mapper.apply(view, a, b)); + } + + @Override + public void replaceWith(@Nullable A a, @Nullable B b) { + replaceWithMapped(mapper.apply(view, a, b)); + } + + @Override + public void remove() { + removeMapped(); + } + } +} diff --git a/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/collector/ToListUniNeighborhoodsCollector.java b/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/collector/ToListUniNeighborhoodsCollector.java new file mode 100644 index 00000000000..219e84a4db1 --- /dev/null +++ b/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/collector/ToListUniNeighborhoodsCollector.java @@ -0,0 +1,88 @@ +package ai.timefold.solver.core.impl.neighborhood.stream.enumerating.collector; + +import java.util.List; +import java.util.Objects; +import java.util.function.Function; +import java.util.function.Supplier; + +import ai.timefold.solver.core.impl.score.stream.collector.AbstractToListSlot; +import ai.timefold.solver.core.preview.api.move.SolutionView; +import ai.timefold.solver.core.preview.api.neighborhood.stream.enumerating.collector.UniNeighborhoodsCollector; +import ai.timefold.solver.core.preview.api.neighborhood.stream.enumerating.collector.UniNeighborhoodsCollectorAccumulator; +import ai.timefold.solver.core.preview.api.neighborhood.stream.enumerating.collector.UniNeighborhoodsCollectorValueHandle; +import ai.timefold.solver.core.preview.api.neighborhood.stream.function.UniNeighborhoodsMapper; + +import org.jspecify.annotations.NullMarked; +import org.jspecify.annotations.Nullable; + +@NullMarked +public final class ToListUniNeighborhoodsCollector + implements UniNeighborhoodsCollector, List> { + + private final UniNeighborhoodsMapper mapper; + + private ToListUniNeighborhoodsCollector(UniNeighborhoodsMapper mapper) { + this.mapper = Objects.requireNonNull(mapper); + } + + public static ToListUniNeighborhoodsCollector + create(UniNeighborhoodsMapper mapper) { + return new ToListUniNeighborhoodsCollector<>(mapper); + } + + @Override + public Supplier> supplier() { + return AbstractToListSlot.State::new; + } + + @Override + public UniNeighborhoodsCollectorAccumulator> accumulator() { + return (view, state) -> new Slot(state, view); + } + + @Override + public Function, @Nullable List> finisher() { + return AbstractToListSlot.State::result; + } + + @Override + public boolean equals(Object object) { + if (this == object) { + return true; + } + return object instanceof ToListUniNeighborhoodsCollector other + && Objects.equals(mapper, other.mapper); + } + + @Override + public int hashCode() { + return Objects.hash(mapper); + } + + private final class Slot + extends AbstractToListSlot + implements UniNeighborhoodsCollectorValueHandle { + + private final SolutionView view; + + Slot(State state, SolutionView view) { + super(state); + this.view = view; + } + + @Override + public void add(@Nullable A a) { + addMapped(mapper.apply(view, a)); + } + + @Override + public void replaceWith(@Nullable A a) { + replaceWithMapped(mapper.apply(view, a)); + } + + @Override + public void remove() { + removeMapped(); + } + } +} diff --git a/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/common/AbstractNeighborhoodsGroupNodeConstructor.java b/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/common/AbstractNeighborhoodsGroupNodeConstructor.java new file mode 100644 index 00000000000..66e8eb616e6 --- /dev/null +++ b/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/common/AbstractNeighborhoodsGroupNodeConstructor.java @@ -0,0 +1,50 @@ +package ai.timefold.solver.core.impl.neighborhood.stream.enumerating.common; + +import java.util.List; +import java.util.Objects; +import java.util.function.Function; + +import ai.timefold.solver.core.config.solver.EnvironmentMode; +import ai.timefold.solver.core.impl.bavet.common.AbstractNodeBuildHelper; +import ai.timefold.solver.core.impl.bavet.common.BavetStream; +import ai.timefold.solver.core.impl.bavet.common.GroupNodeConstructor; +import ai.timefold.solver.core.impl.bavet.common.tuple.Tuple; +import ai.timefold.solver.core.preview.api.move.SolutionView; + +import org.jspecify.annotations.NullMarked; + +@NullMarked +abstract sealed class AbstractNeighborhoodsGroupNodeConstructor + implements NeighborhoodsGroupNodeConstructor + permits NeighborhoodsGroupNodeConstructorWithAccumulate, + NeighborhoodsGroupNodeConstructorWithoutAccumulate { + + private final Object equalityKey; + private final Function, GroupNodeConstructor> factory; + + protected AbstractNeighborhoodsGroupNodeConstructor(Object equalityKey, + Function, GroupNodeConstructor> factory) { + this.equalityKey = Objects.requireNonNull(equalityKey); + this.factory = Objects.requireNonNull(factory); + } + + @Override + public void build(AbstractNodeBuildHelper buildHelper, + Stream_ parentTupleSource, Stream_ aftStream, List aftStreamChildList, + Stream_ thisStream, EnvironmentMode environmentMode, SolutionView view) { + factory.apply(view).build(buildHelper, parentTupleSource, aftStream, + aftStreamChildList, thisStream, environmentMode); + } + + @Override + public boolean equals(Object o) { + return o instanceof AbstractNeighborhoodsGroupNodeConstructor that + && Objects.equals(getClass(), that.getClass()) + && Objects.equals(equalityKey, that.equalityKey); + } + + @Override + public int hashCode() { + return Objects.hashCode(equalityKey); + } +} diff --git a/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/common/NeighborhoodsGroupNodeConstructor.java b/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/common/NeighborhoodsGroupNodeConstructor.java new file mode 100644 index 00000000000..4dc30aee2b5 --- /dev/null +++ b/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/common/NeighborhoodsGroupNodeConstructor.java @@ -0,0 +1,107 @@ +package ai.timefold.solver.core.impl.neighborhood.stream.enumerating.common; + +import java.util.List; + +import ai.timefold.solver.core.config.solver.EnvironmentMode; +import ai.timefold.solver.core.impl.bavet.bi.Group0Mapping1CollectorBiNode; +import ai.timefold.solver.core.impl.bavet.bi.Group1Mapping0CollectorBiNode; +import ai.timefold.solver.core.impl.bavet.bi.Group1Mapping1CollectorBiNode; +import ai.timefold.solver.core.impl.bavet.bi.Group2Mapping0CollectorBiNode; +import ai.timefold.solver.core.impl.bavet.common.AbstractNodeBuildHelper; +import ai.timefold.solver.core.impl.bavet.common.BavetStream; +import ai.timefold.solver.core.impl.bavet.common.GroupNodeConstructor; +import ai.timefold.solver.core.impl.bavet.common.tuple.BiTuple; +import ai.timefold.solver.core.impl.bavet.common.tuple.Tuple; +import ai.timefold.solver.core.impl.bavet.common.tuple.UniTuple; +import ai.timefold.solver.core.impl.bavet.uni.Group0Mapping1CollectorUniNode; +import ai.timefold.solver.core.impl.bavet.uni.Group1Mapping0CollectorUniNode; +import ai.timefold.solver.core.impl.bavet.uni.Group1Mapping1CollectorUniNode; +import ai.timefold.solver.core.impl.neighborhood.stream.enumerating.collector.NeighborhoodsCollectorUtils; +import ai.timefold.solver.core.impl.util.Pair; +import ai.timefold.solver.core.preview.api.move.SolutionView; +import ai.timefold.solver.core.preview.api.neighborhood.stream.enumerating.collector.BiNeighborhoodsCollector; +import ai.timefold.solver.core.preview.api.neighborhood.stream.enumerating.collector.UniNeighborhoodsCollector; +import ai.timefold.solver.core.preview.api.neighborhood.stream.function.BiNeighborhoodsMapper; +import ai.timefold.solver.core.preview.api.neighborhood.stream.function.UniNeighborhoodsMapper; + +import org.jspecify.annotations.NullMarked; + +@NullMarked +public sealed interface NeighborhoodsGroupNodeConstructor + permits AbstractNeighborhoodsGroupNodeConstructor { + + void build(AbstractNodeBuildHelper buildHelper, + Stream_ parentTupleSource, Stream_ aftStream, List aftStreamChildList, + Stream_ thisStream, EnvironmentMode environmentMode, SolutionView view); + + static + NeighborhoodsGroupNodeConstructor> + uniOneKeyGroupBy(UniNeighborhoodsMapper key) { + return new NeighborhoodsGroupNodeConstructorWithoutAccumulate<>(key, + view -> GroupNodeConstructor.oneKeyGroupBy(key.toFunction(view), + Group1Mapping0CollectorUniNode::new)); + } + + @SuppressWarnings("unchecked") + static + NeighborhoodsGroupNodeConstructor> + uniZeroKeysGroupBy(UniNeighborhoodsCollector collector) { + return new NeighborhoodsGroupNodeConstructorWithAccumulate<>(collector, + view -> GroupNodeConstructor.zeroKeysGroupBy( + NeighborhoodsCollectorUtils.toConstraintCollector( + (UniNeighborhoodsCollector) collector, view), + Group0Mapping1CollectorUniNode::new)); + } + + @SuppressWarnings("unchecked") + static + NeighborhoodsGroupNodeConstructor> + uniOneKeyAndCollectorGroupBy(UniNeighborhoodsMapper key, + UniNeighborhoodsCollector collector) { + return new NeighborhoodsGroupNodeConstructorWithAccumulate<>(new Pair<>(key, collector), + view -> GroupNodeConstructor.oneKeyGroupBy(key.toFunction(view), + NeighborhoodsCollectorUtils.toConstraintCollector( + (UniNeighborhoodsCollector) collector, view), + Group1Mapping1CollectorUniNode::new)); + } + + static + NeighborhoodsGroupNodeConstructor> + biOneKeyGroupBy(BiNeighborhoodsMapper key) { + return new NeighborhoodsGroupNodeConstructorWithoutAccumulate<>(key, + view -> GroupNodeConstructor.oneKeyGroupBy(key.toBiFunction(view), + Group1Mapping0CollectorBiNode::new)); + } + + static + NeighborhoodsGroupNodeConstructor> + biTwoKeysGroupBy(BiNeighborhoodsMapper keyA, + BiNeighborhoodsMapper keyB) { + return new NeighborhoodsGroupNodeConstructorWithoutAccumulate<>(new Pair<>(keyA, keyB), + view -> GroupNodeConstructor.twoKeysGroupBy(keyA.toBiFunction(view), keyB.toBiFunction(view), + Group2Mapping0CollectorBiNode::new)); + } + + @SuppressWarnings("unchecked") + static + NeighborhoodsGroupNodeConstructor> + biZeroKeysGroupBy(BiNeighborhoodsCollector collector) { + return new NeighborhoodsGroupNodeConstructorWithAccumulate<>(collector, + view -> GroupNodeConstructor.zeroKeysGroupBy( + NeighborhoodsCollectorUtils.toConstraintCollector( + (BiNeighborhoodsCollector) collector, view), + Group0Mapping1CollectorBiNode::new)); + } + + @SuppressWarnings("unchecked") + static + NeighborhoodsGroupNodeConstructor> + biOneKeyAndCollectorGroupBy(BiNeighborhoodsMapper key, + BiNeighborhoodsCollector collector) { + return new NeighborhoodsGroupNodeConstructorWithAccumulate<>(new Pair<>(key, collector), + view -> GroupNodeConstructor.oneKeyGroupBy(key.toBiFunction(view), + NeighborhoodsCollectorUtils.toConstraintCollector( + (BiNeighborhoodsCollector) collector, view), + Group1Mapping1CollectorBiNode::new)); + } +} diff --git a/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/common/NeighborhoodsGroupNodeConstructorWithAccumulate.java b/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/common/NeighborhoodsGroupNodeConstructorWithAccumulate.java new file mode 100644 index 00000000000..ce31acd31c9 --- /dev/null +++ b/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/common/NeighborhoodsGroupNodeConstructorWithAccumulate.java @@ -0,0 +1,19 @@ +package ai.timefold.solver.core.impl.neighborhood.stream.enumerating.common; + +import java.util.function.Function; + +import ai.timefold.solver.core.impl.bavet.common.GroupNodeConstructor; +import ai.timefold.solver.core.impl.bavet.common.tuple.Tuple; +import ai.timefold.solver.core.preview.api.move.SolutionView; + +import org.jspecify.annotations.NullMarked; + +@NullMarked +final class NeighborhoodsGroupNodeConstructorWithAccumulate + extends AbstractNeighborhoodsGroupNodeConstructor { + + NeighborhoodsGroupNodeConstructorWithAccumulate(Object equalityKey, + Function, GroupNodeConstructor> factory) { + super(equalityKey, factory); + } +} diff --git a/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/common/NeighborhoodsGroupNodeConstructorWithoutAccumulate.java b/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/common/NeighborhoodsGroupNodeConstructorWithoutAccumulate.java new file mode 100644 index 00000000000..25f7e9c9422 --- /dev/null +++ b/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/common/NeighborhoodsGroupNodeConstructorWithoutAccumulate.java @@ -0,0 +1,19 @@ +package ai.timefold.solver.core.impl.neighborhood.stream.enumerating.common; + +import java.util.function.Function; + +import ai.timefold.solver.core.impl.bavet.common.GroupNodeConstructor; +import ai.timefold.solver.core.impl.bavet.common.tuple.Tuple; +import ai.timefold.solver.core.preview.api.move.SolutionView; + +import org.jspecify.annotations.NullMarked; + +@NullMarked +final class NeighborhoodsGroupNodeConstructorWithoutAccumulate + extends AbstractNeighborhoodsGroupNodeConstructor { + + NeighborhoodsGroupNodeConstructorWithoutAccumulate(Object equalityKey, + Function, GroupNodeConstructor> factory) { + super(equalityKey, factory); + } +} diff --git a/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/uni/AbstractUniEnumeratingStream.java b/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/uni/AbstractUniEnumeratingStream.java index 280482ce660..d26b062bcbc 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/uni/AbstractUniEnumeratingStream.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/uni/AbstractUniEnumeratingStream.java @@ -1,16 +1,12 @@ package ai.timefold.solver.core.impl.neighborhood.stream.enumerating.uni; -import static ai.timefold.solver.core.impl.bavet.common.GroupNodeConstructor.oneKeyGroupBy; - -import java.util.Objects; -import java.util.function.Function; - -import ai.timefold.solver.core.impl.bavet.common.GroupNodeConstructor; +import ai.timefold.solver.core.impl.bavet.common.tuple.BiTuple; import ai.timefold.solver.core.impl.bavet.common.tuple.UniTuple; -import ai.timefold.solver.core.impl.bavet.uni.Group1Mapping0CollectorUniNode; import ai.timefold.solver.core.impl.neighborhood.stream.enumerating.EnumeratingStreamFactory; +import ai.timefold.solver.core.impl.neighborhood.stream.enumerating.bi.AbstractBiEnumeratingStream; import ai.timefold.solver.core.impl.neighborhood.stream.enumerating.bi.JoinBiEnumeratingStream; import ai.timefold.solver.core.impl.neighborhood.stream.enumerating.common.AbstractEnumeratingStream; +import ai.timefold.solver.core.impl.neighborhood.stream.enumerating.common.NeighborhoodsGroupNodeConstructor; import ai.timefold.solver.core.impl.neighborhood.stream.enumerating.common.bridge.AftBridgeBiEnumeratingStream; import ai.timefold.solver.core.impl.neighborhood.stream.enumerating.common.bridge.AftBridgeUniEnumeratingStream; import ai.timefold.solver.core.impl.neighborhood.stream.enumerating.common.bridge.ForeBridgeUniEnumeratingStream; @@ -18,6 +14,7 @@ import ai.timefold.solver.core.impl.util.ConstantLambdaUtils; import ai.timefold.solver.core.preview.api.neighborhood.stream.enumerating.BiEnumeratingStream; import ai.timefold.solver.core.preview.api.neighborhood.stream.enumerating.UniEnumeratingStream; +import ai.timefold.solver.core.preview.api.neighborhood.stream.enumerating.collector.UniNeighborhoodsCollector; import ai.timefold.solver.core.preview.api.neighborhood.stream.function.UniNeighborhoodsMapper; import ai.timefold.solver.core.preview.api.neighborhood.stream.function.UniNeighborhoodsPredicate; import ai.timefold.solver.core.preview.api.neighborhood.stream.joiner.BiNeighborhoodsJoiner; @@ -103,31 +100,39 @@ private UniEnumeratingStream ifExistsOrNot(boolean shouldExist joinerComber.mergedJoiner(), joinerComber.mergedFiltering()), childStreamList::add); } - /** - * Convert the {@link UniEnumeratingStream} to a different {@link UniEnumeratingStream}, - * containing the set of tuples resulting from applying the group key mapping function - * on all tuples of the original stream. - * Neither tuple of the new stream {@link Objects#equals(Object, Object)} any other. - * - * @param groupKeyMapping mapping function to convert each element in the stream to a different element - * @param the type of a fact in the destination {@link UniEnumeratingStream}'s tuple; - * must honor {@link Object#hashCode() the general contract of hashCode}. - */ - protected AbstractUniEnumeratingStream groupBy(Function groupKeyMapping) { - // We do not expose this on the API, as this operation is not yet needed in any of the moves. - // The groupBy API will need revisiting if exposed as a feature of Neighborhoods API, do not expose as is. - GroupNodeConstructor> nodeConstructor = - oneKeyGroupBy(groupKeyMapping, Group1Mapping0CollectorUniNode::new); - return buildUniGroupBy(nodeConstructor); - } - - private AbstractUniEnumeratingStream - buildUniGroupBy(GroupNodeConstructor> nodeConstructor) { + @Override + public AbstractUniEnumeratingStream groupBy( + UniNeighborhoodsMapper key) { + return buildUniGroupBy(NeighborhoodsGroupNodeConstructor.uniOneKeyGroupBy(key)); + } + + @Override + public AbstractUniEnumeratingStream groupBy( + UniNeighborhoodsCollector collector) { + return buildUniGroupBy(NeighborhoodsGroupNodeConstructor.uniZeroKeysGroupBy(collector)); + } + + @Override + public AbstractBiEnumeratingStream groupBy( + UniNeighborhoodsMapper key, + UniNeighborhoodsCollector collector) { + return buildBiGroupBy(NeighborhoodsGroupNodeConstructor.uniOneKeyAndCollectorGroupBy(key, collector)); + } + + private AbstractUniEnumeratingStream buildUniGroupBy( + NeighborhoodsGroupNodeConstructor> nodeConstructor) { var stream = shareAndAddChild(new UniGroupUniEnumeratingStream<>(enumeratingStreamFactory, this, nodeConstructor)); return enumeratingStreamFactory.share(new AftBridgeUniEnumeratingStream<>(enumeratingStreamFactory, stream), stream::setAftBridge); } + private AbstractBiEnumeratingStream buildBiGroupBy( + NeighborhoodsGroupNodeConstructor> nodeConstructor) { + var stream = shareAndAddChild(new UniGroupBiEnumeratingStream<>(enumeratingStreamFactory, this, nodeConstructor)); + return enumeratingStreamFactory.share(new AftBridgeBiEnumeratingStream<>(enumeratingStreamFactory, stream), + stream::setAftBridge); + } + @Override public UniEnumeratingStream map(UniNeighborhoodsMapper mapping) { var stream = shareAndAddChild(new UniMapUniEnumeratingStream<>(enumeratingStreamFactory, this, mapping)); @@ -149,7 +154,7 @@ public AbstractUniEnumeratingStream distinct() { if (guaranteesDistinct()) { return this; // Already distinct, no need to create a new stream. } - return groupBy(ConstantLambdaUtils.identity()); + return groupBy(ConstantLambdaUtils.neighborhoodsUniPickFirst()); } public UniLeftDataset createLeftDataset() { diff --git a/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/uni/UniGroupBiEnumeratingStream.java b/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/uni/UniGroupBiEnumeratingStream.java new file mode 100644 index 00000000000..0601321a219 --- /dev/null +++ b/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/uni/UniGroupBiEnumeratingStream.java @@ -0,0 +1,63 @@ +package ai.timefold.solver.core.impl.neighborhood.stream.enumerating.uni; + +import java.util.Objects; + +import ai.timefold.solver.core.impl.bavet.common.tuple.BiTuple; +import ai.timefold.solver.core.impl.neighborhood.stream.enumerating.EnumeratingStreamFactory; +import ai.timefold.solver.core.impl.neighborhood.stream.enumerating.common.DataNodeBuildHelper; +import ai.timefold.solver.core.impl.neighborhood.stream.enumerating.common.NeighborhoodsGroupNodeConstructor; +import ai.timefold.solver.core.impl.neighborhood.stream.enumerating.common.bridge.AftBridgeBiEnumeratingStream; + +import org.jspecify.annotations.NullMarked; +import org.jspecify.annotations.Nullable; + +@NullMarked +final class UniGroupBiEnumeratingStream + extends AbstractUniEnumeratingStream { + + private final NeighborhoodsGroupNodeConstructor> nodeConstructor; + private @Nullable AftBridgeBiEnumeratingStream aftStream; + + UniGroupBiEnumeratingStream(EnumeratingStreamFactory enumeratingStreamFactory, + AbstractUniEnumeratingStream parent, + NeighborhoodsGroupNodeConstructor> nodeConstructor) { + super(enumeratingStreamFactory, parent); + this.nodeConstructor = Objects.requireNonNull(nodeConstructor); + } + + void setAftBridge(AftBridgeBiEnumeratingStream aftStream) { + this.aftStream = aftStream; + } + + @Override + public boolean guaranteesDistinct() { + return true; + } + + @Override + public void buildNode(DataNodeBuildHelper buildHelper) { + var view = buildHelper.getSessionContext().solutionView(); + nodeConstructor.build(buildHelper, parent.getTupleSource(), aftStream, + aftStream.getChildStreamList(), this, enumeratingStreamFactory.getEnvironmentMode(), view); + } + + @Override + public boolean equals(Object object) { + if (this == object) + return true; + if (object == null || getClass() != object.getClass()) + return false; + var that = (UniGroupBiEnumeratingStream) object; + return Objects.equals(parent, that.parent) && Objects.equals(nodeConstructor, that.nodeConstructor); + } + + @Override + public int hashCode() { + return Objects.hash(parent, nodeConstructor); + } + + @Override + public String toString() { + return "UniGroupBi()"; + } +} diff --git a/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/uni/UniGroupUniEnumeratingStream.java b/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/uni/UniGroupUniEnumeratingStream.java index b5dd284b630..bd935a59156 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/uni/UniGroupUniEnumeratingStream.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/uni/UniGroupUniEnumeratingStream.java @@ -2,10 +2,10 @@ import java.util.Objects; -import ai.timefold.solver.core.impl.bavet.common.GroupNodeConstructor; import ai.timefold.solver.core.impl.bavet.common.tuple.UniTuple; import ai.timefold.solver.core.impl.neighborhood.stream.enumerating.EnumeratingStreamFactory; import ai.timefold.solver.core.impl.neighborhood.stream.enumerating.common.DataNodeBuildHelper; +import ai.timefold.solver.core.impl.neighborhood.stream.enumerating.common.NeighborhoodsGroupNodeConstructor; import ai.timefold.solver.core.impl.neighborhood.stream.enumerating.common.bridge.AftBridgeUniEnumeratingStream; import org.jspecify.annotations.NullMarked; @@ -15,35 +15,32 @@ final class UniGroupUniEnumeratingStream extends AbstractUniEnumeratingStream { - private final GroupNodeConstructor> nodeConstructor; + private final NeighborhoodsGroupNodeConstructor> nodeConstructor; private @Nullable AftBridgeUniEnumeratingStream aftStream; - public UniGroupUniEnumeratingStream(EnumeratingStreamFactory enumeratingStreamFactory, + UniGroupUniEnumeratingStream(EnumeratingStreamFactory enumeratingStreamFactory, AbstractUniEnumeratingStream parent, - GroupNodeConstructor> nodeConstructor) { + NeighborhoodsGroupNodeConstructor> nodeConstructor) { super(enumeratingStreamFactory, parent); - this.nodeConstructor = nodeConstructor; + this.nodeConstructor = Objects.requireNonNull(nodeConstructor); } - public void setAftBridge(AftBridgeUniEnumeratingStream aftStream) { + void setAftBridge(AftBridgeUniEnumeratingStream aftStream) { this.aftStream = aftStream; } - // ************************************************************************ - // Node creation - // ************************************************************************ + @Override + public boolean guaranteesDistinct() { + return true; + } @Override public void buildNode(DataNodeBuildHelper buildHelper) { - var aftStreamChildList = aftStream.getChildStreamList(); - nodeConstructor.build(buildHelper, parent.getTupleSource(), aftStream, aftStreamChildList, this, - enumeratingStreamFactory.getEnvironmentMode()); + var view = buildHelper.getSessionContext().solutionView(); + nodeConstructor.build(buildHelper, parent.getTupleSource(), aftStream, + aftStream.getChildStreamList(), this, enumeratingStreamFactory.getEnvironmentMode(), view); } - // ************************************************************************ - // Equality for node sharing - // ************************************************************************ - @Override public boolean equals(Object object) { if (this == object) @@ -61,7 +58,6 @@ public int hashCode() { @Override public String toString() { - return "UniGroup()"; + return "UniGroupUni()"; } - } diff --git a/core/src/main/java/ai/timefold/solver/core/impl/util/ConstantLambdaUtils.java b/core/src/main/java/ai/timefold/solver/core/impl/util/ConstantLambdaUtils.java index 8e40d04986e..9695f5c4cc3 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/util/ConstantLambdaUtils.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/util/ConstantLambdaUtils.java @@ -1,34 +1,31 @@ package ai.timefold.solver.core.impl.util; +import java.lang.invoke.ConstantCallSite; import java.math.BigDecimal; import java.util.Objects; import java.util.function.BiFunction; import java.util.function.BiPredicate; import java.util.function.Function; -import java.util.function.ToIntBiFunction; -import java.util.function.ToIntFunction; import java.util.function.ToLongBiFunction; import java.util.function.ToLongFunction; import ai.timefold.solver.core.api.function.QuadFunction; -import ai.timefold.solver.core.api.function.ToIntQuadFunction; -import ai.timefold.solver.core.api.function.ToIntTriFunction; import ai.timefold.solver.core.api.function.ToLongQuadFunction; import ai.timefold.solver.core.api.function.ToLongTriFunction; import ai.timefold.solver.core.api.function.TriFunction; +import ai.timefold.solver.core.preview.api.neighborhood.stream.function.BiNeighborhoodsMapper; +import ai.timefold.solver.core.preview.api.neighborhood.stream.function.UniNeighborhoodsMapper; /** * A class that holds common lambdas that are guaranteed to be the same across method calls. * In most JDK's, - * stateless lambdas are bound to a {@link java.lang.invoke.ConstantCallSite} inside the method that defined them, - * but that {@link java.lang.invoke.ConstantCallSite} is not shared across methods, + * stateless lambdas are bound to a {@link ConstantCallSite} inside the method that defined them, + * but that {@link ConstantCallSite} is not shared across methods, * even for methods in the same class. * Thus, when lambda reference equality is important (such as for node sharing in Constraint Streams), * the lambdas in this class should be used. */ public final class ConstantLambdaUtils { - private static final Runnable NO_OP = () -> { - }; @SuppressWarnings("rawtypes") private static final Function IDENTITY = Function.identity(); @@ -69,9 +66,6 @@ public final class ConstantLambdaUtils { @SuppressWarnings("rawtypes") private static final ToLongFunction UNI_CONSTANT_ZERO_LONG = a -> 0L; - @SuppressWarnings("rawtypes") - private static final ToIntFunction UNI_CONSTANT_ONE = a -> 1; - @SuppressWarnings("rawtypes") private static final ToLongFunction UNI_CONSTANT_ONE_LONG = a -> 1L; @@ -84,9 +78,6 @@ public final class ConstantLambdaUtils { @SuppressWarnings("rawtypes") private static final ToLongBiFunction BI_CONSTANT_ZERO_LONG = (a, b) -> 0L; - @SuppressWarnings("rawtypes") - private static final ToIntBiFunction BI_CONSTANT_ONE = (a, b) -> 1; - @SuppressWarnings("rawtypes") private static final ToLongBiFunction BI_CONSTANT_ONE_LONG = (a, b) -> 1L; @@ -99,9 +90,6 @@ public final class ConstantLambdaUtils { @SuppressWarnings("rawtypes") private static final ToLongTriFunction TRI_CONSTANT_ZERO_LONG = (a, b, c) -> 0L; - @SuppressWarnings("rawtypes") - private static final ToIntTriFunction TRI_CONSTANT_ONE = (a, b, c) -> 1; - @SuppressWarnings("rawtypes") private static final ToLongTriFunction TRI_CONSTANT_ONE_LONG = (a, b, c) -> 1L; @@ -111,15 +99,20 @@ public final class ConstantLambdaUtils { @SuppressWarnings("rawtypes") private static final ToLongQuadFunction QUAD_CONSTANT_ZERO_LONG = (a, b, c, d) -> 0L; - @SuppressWarnings("rawtypes") - private static final ToIntQuadFunction QUAD_CONSTANT_ONE = (a, b, c, d) -> 1; - @SuppressWarnings("rawtypes") private static final ToLongQuadFunction QUAD_CONSTANT_ONE_LONG = (a, b, c, d) -> 1L; @SuppressWarnings("rawtypes") private static final QuadFunction QUAD_CONSTANT_ONE_BIG_DECiMAL = (a, b, c, d) -> BigDecimal.ONE; + private static final UniNeighborhoodsMapper NEIGHBORHOODS_UNI_PICK_FIRST = (view, a) -> a; + + private static final BiNeighborhoodsMapper NEIGHBORHOODS_BI_PICK_FIRST = + (view, a, b) -> a; + + private static final BiNeighborhoodsMapper NEIGHBORHOODS_BI_PICK_SECOND = + (view, a, b) -> b; + public static Function uncheck(ThrowableFunction function) { return t -> { try { @@ -130,15 +123,6 @@ public static Function uncheck(ThrowableFunction function) { }; } - /** - * Returns a {@link Runnable} that does nothing. - * - * @return never null - */ - public static Runnable noop() { - return NO_OP; - } - /** * Returns a {@link Function} that returns its only input. * @@ -260,16 +244,6 @@ public static Function uniConstantNull() { return UNI_CONSTANT_NULL; } - /** - * Returns a {@link ToIntFunction} that returns the constant 1. - * - * @return never null - */ - @SuppressWarnings("unchecked") - public static ToIntFunction uniConstantOne() { - return UNI_CONSTANT_ONE; - } - /** * Returns a {@link ToLongFunction} that returns the constant 0. * @@ -320,16 +294,6 @@ public static ToLongBiFunction biConstantZeroLong() { return BI_CONSTANT_ZERO_LONG; } - /** - * Returns a {@link ToIntBiFunction} that returns the constant 1. - * - * @return never null - */ - @SuppressWarnings("unchecked") - public static ToIntBiFunction biConstantOne() { - return BI_CONSTANT_ONE; - } - /** * Returns a {@link ToLongBiFunction} that returns the constant 1. * @@ -370,16 +334,6 @@ public static ToLongTriFunction triConstantZeroLong() { return TRI_CONSTANT_ZERO_LONG; } - /** - * Returns a {@link ToIntTriFunction} that returns the constant 1. - * - * @return never null - */ - @SuppressWarnings("unchecked") - public static ToIntTriFunction triConstantOne() { - return TRI_CONSTANT_ONE; - } - /** * Returns a {@link ToLongTriFunction} that returns the constant 1. * @@ -410,16 +364,6 @@ public static ToLongQuadFunction quadConstantZeroLong() return QUAD_CONSTANT_ZERO_LONG; } - /** - * Returns a {@link ToIntQuadFunction} that returns the constant 1. - * - * @return never null - */ - @SuppressWarnings("unchecked") - public static ToIntQuadFunction quadConstantOne() { - return QUAD_CONSTANT_ONE; - } - /** * Returns a {@link ToLongQuadFunction} that returns the constant 1. * @@ -440,6 +384,21 @@ public static QuadFunction quadConstantOneB return QUAD_CONSTANT_ONE_BIG_DECiMAL; } + @SuppressWarnings("unchecked") + public static UniNeighborhoodsMapper neighborhoodsUniPickFirst() { + return (UniNeighborhoodsMapper) NEIGHBORHOODS_UNI_PICK_FIRST; + } + + @SuppressWarnings("unchecked") + public static BiNeighborhoodsMapper neighborhoodsBiPickFirst() { + return (BiNeighborhoodsMapper) NEIGHBORHOODS_BI_PICK_FIRST; + } + + @SuppressWarnings("unchecked") + public static BiNeighborhoodsMapper neighborhoodsBiPickSecond() { + return (BiNeighborhoodsMapper) NEIGHBORHOODS_BI_PICK_SECOND; + } + @FunctionalInterface public interface ThrowableFunction { R apply(T t) throws Throwable; diff --git a/core/src/main/java/ai/timefold/solver/core/preview/api/neighborhood/stream/enumerating/BiEnumeratingStream.java b/core/src/main/java/ai/timefold/solver/core/preview/api/neighborhood/stream/enumerating/BiEnumeratingStream.java index a2bce5b0f41..90d478d005e 100644 --- a/core/src/main/java/ai/timefold/solver/core/preview/api/neighborhood/stream/enumerating/BiEnumeratingStream.java +++ b/core/src/main/java/ai/timefold/solver/core/preview/api/neighborhood/stream/enumerating/BiEnumeratingStream.java @@ -1,9 +1,11 @@ package ai.timefold.solver.core.preview.api.neighborhood.stream.enumerating; -import ai.timefold.solver.core.preview.api.move.SolutionView; +import ai.timefold.solver.core.preview.api.neighborhood.stream.enumerating.collector.BiNeighborhoodsCollector; +import ai.timefold.solver.core.preview.api.neighborhood.stream.enumerating.collector.UniNeighborhoodsCollector; import ai.timefold.solver.core.preview.api.neighborhood.stream.function.BiNeighborhoodsMapper; import ai.timefold.solver.core.preview.api.neighborhood.stream.function.BiNeighborhoodsPredicate; import ai.timefold.solver.core.preview.api.neighborhood.stream.function.UniNeighborhoodsMapper; +import ai.timefold.solver.core.preview.api.neighborhood.stream.function.UniNeighborhoodsPredicate; import org.jspecify.annotations.NullMarked; @@ -11,8 +13,7 @@ public interface BiEnumeratingStream extends EnumeratingStream { /** - * Exhaustively test each fact against the {@link BiNeighborhoodsPredicate} - * and match if {@link BiNeighborhoodsPredicate#test(SolutionView, Object, Object)} returns true. + * As defined by {@link UniEnumeratingStream#filter(UniNeighborhoodsPredicate)}. */ BiEnumeratingStream filter(BiNeighborhoodsPredicate filter); @@ -22,28 +23,38 @@ public interface BiEnumeratingStream extends EnumeratingStream /** * As defined by {@link UniEnumeratingStream#map(UniNeighborhoodsMapper)}. - * - *

- * Use with caution, - * as the increased memory allocation rates coming from tuple creation may negatively affect performance. - * - * @param mapping function to convert the original tuple into the new tuple - * @param the type of the only fact in the resulting {@link UniEnumeratingStream}'s tuple */ UniEnumeratingStream map(BiNeighborhoodsMapper mapping); /** * As defined by {@link #map(BiNeighborhoodsMapper)}, only resulting in {@link BiEnumeratingStream}. - * - * @param mappingA function to convert the original tuple into the first fact of a new tuple - * @param mappingB function to convert the original tuple into the second fact of a new tuple - * @param the type of the first fact in the resulting {@link BiEnumeratingStream}'s tuple - * @param the type of the first fact in the resulting {@link BiEnumeratingStream}'s tuple */ BiEnumeratingStream map( BiNeighborhoodsMapper mappingA, BiNeighborhoodsMapper mappingB); + /** + * As defined by {@link UniEnumeratingStream#groupBy(UniNeighborhoodsMapper)}, + * only for {@link BiEnumeratingStream} sources. + */ + UniEnumeratingStream groupBy(BiNeighborhoodsMapper key); + + /** + * As defined by + * {@link UniEnumeratingStream#groupBy(UniNeighborhoodsCollector)}, + * only for {@link BiEnumeratingStream} sources. + */ + UniEnumeratingStream groupBy(BiNeighborhoodsCollector collector); + + /** + * As defined by + * {@link UniEnumeratingStream#groupBy(UniNeighborhoodsMapper, UniNeighborhoodsCollector)}, + * only for {@link BiEnumeratingStream} sources. + */ + BiEnumeratingStream groupBy( + BiNeighborhoodsMapper key, + BiNeighborhoodsCollector collector); + /** * As defined by {@link UniEnumeratingStream#distinct()}. */ diff --git a/core/src/main/java/ai/timefold/solver/core/preview/api/neighborhood/stream/enumerating/UniEnumeratingStream.java b/core/src/main/java/ai/timefold/solver/core/preview/api/neighborhood/stream/enumerating/UniEnumeratingStream.java index 7366acdcd16..9d9db0a4b91 100644 --- a/core/src/main/java/ai/timefold/solver/core/preview/api/neighborhood/stream/enumerating/UniEnumeratingStream.java +++ b/core/src/main/java/ai/timefold/solver/core/preview/api/neighborhood/stream/enumerating/UniEnumeratingStream.java @@ -1,6 +1,7 @@ package ai.timefold.solver.core.preview.api.neighborhood.stream.enumerating; import ai.timefold.solver.core.preview.api.move.SolutionView; +import ai.timefold.solver.core.preview.api.neighborhood.stream.enumerating.collector.UniNeighborhoodsCollector; import ai.timefold.solver.core.preview.api.neighborhood.stream.function.BiNeighborhoodsPredicate; import ai.timefold.solver.core.preview.api.neighborhood.stream.function.UniNeighborhoodsMapper; import ai.timefold.solver.core.preview.api.neighborhood.stream.function.UniNeighborhoodsPredicate; @@ -413,21 +414,53 @@ default UniEnumeratingStream ifNotExists(Class otherClass, * * @param mapping function to convert the original tuple into the new tuple * @param the type of the only fact in the resulting {@link UniEnumeratingStream}'s tuple + * @return a {@link UniEnumeratingStream} of the new tuples created by the mapping function */ UniEnumeratingStream map(UniNeighborhoodsMapper mapping); /** - * As defined by {@link #map(UniNeighborhoodsMapper)}, only resulting in {@link BiEnumeratingStream}. - * - * @param mappingA function to convert the original tuple into the first fact of a new tuple - * @param mappingB function to convert the original tuple into the second fact of a new tuple - * @param the type of the first fact in the resulting {@link BiEnumeratingStream}'s tuple - * @param the type of the first fact in the resulting {@link BiEnumeratingStream}'s tuple + * As defined by {@link #map(UniNeighborhoodsMapper)}, + * only resulting in {@link BiEnumeratingStream}. */ BiEnumeratingStream map( UniNeighborhoodsMapper mappingA, UniNeighborhoodsMapper mappingB); + /** + * Groups the stream by a single key, producing one element (the key) per group. + * + * @param key mapping function to extract the group key from each element + * @param the type of the group key + * @return a {@link UniEnumeratingStream} where the only fact is the group key, + * and there is one tuple for each group of original tuples that share the same group key + */ + UniEnumeratingStream groupBy(UniNeighborhoodsMapper key); + + /** + * Collects the entire stream into a single group, producing one element (the collected result). + * + * @param collector the collector to apply to the stream + * @param the type of the result + * @return a {@link UniEnumeratingStream} with a single element, + * which is the result of applying the collector to the entire stream + */ + UniEnumeratingStream groupBy(UniNeighborhoodsCollector collector); + + /** + * Groups the stream by a key and applies a collector to each group, + * producing one pair (key, result) per group. + * + * @param key mapping function to extract the group key + * @param collector the collector to apply to each group + * @param the type of the group key + * @param the type of the collected result + * @return a {@link BiEnumeratingStream} where the first fact is the group key + * and the second fact is the collected result for that group + */ + BiEnumeratingStream groupBy( + UniNeighborhoodsMapper key, + UniNeighborhoodsCollector collector); + /** * Transforms the stream in such a way that all the tuples going through it are distinct. * (No two tuples will {@link Object#equals(Object) equal}.) @@ -437,6 +470,9 @@ BiEnumeratingStream map( * However, operations such as {@link #map(UniNeighborhoodsMapper)} may create a stream which breaks that promise. * By calling this method on such a stream, * duplicate copies of the same tuple will be omitted at a performance cost. + * + * @return a stream that is guaranteed to have distinct tuples, + * at the cost of increased time and memory usage */ UniEnumeratingStream distinct(); diff --git a/core/src/main/java/ai/timefold/solver/core/preview/api/neighborhood/stream/enumerating/collector/BiNeighborhoodsCollector.java b/core/src/main/java/ai/timefold/solver/core/preview/api/neighborhood/stream/enumerating/collector/BiNeighborhoodsCollector.java new file mode 100644 index 00000000000..98bfe7b80f8 --- /dev/null +++ b/core/src/main/java/ai/timefold/solver/core/preview/api/neighborhood/stream/enumerating/collector/BiNeighborhoodsCollector.java @@ -0,0 +1,29 @@ +package ai.timefold.solver.core.preview.api.neighborhood.stream.enumerating.collector; + +import java.util.function.Function; +import java.util.function.Supplier; + +import ai.timefold.solver.core.preview.api.neighborhood.stream.enumerating.BiEnumeratingStream; + +import org.jspecify.annotations.NullMarked; +import org.jspecify.annotations.Nullable; + +/** + * As defined by {@link UniNeighborhoodsCollector}, only for {@link BiEnumeratingStream}. + * + * @param the type of the solution + * @param the type of the first fact in the source stream's tuple + * @param the type of the second fact in the source stream's tuple + * @param the mutable accumulation type (often hidden as an implementation detail) + * @param the type of the result + */ +@NullMarked +public interface BiNeighborhoodsCollector { + + Supplier supplier(); + + BiNeighborhoodsCollectorAccumulator accumulator(); + + Function finisher(); + +} diff --git a/core/src/main/java/ai/timefold/solver/core/preview/api/neighborhood/stream/enumerating/collector/BiNeighborhoodsCollectorAccumulator.java b/core/src/main/java/ai/timefold/solver/core/preview/api/neighborhood/stream/enumerating/collector/BiNeighborhoodsCollectorAccumulator.java new file mode 100644 index 00000000000..9bdc6b926a8 --- /dev/null +++ b/core/src/main/java/ai/timefold/solver/core/preview/api/neighborhood/stream/enumerating/collector/BiNeighborhoodsCollectorAccumulator.java @@ -0,0 +1,20 @@ +package ai.timefold.solver.core.preview.api.neighborhood.stream.enumerating.collector; + +import ai.timefold.solver.core.preview.api.move.SolutionView; + +import org.jspecify.annotations.NullMarked; + +/** + * As defined by {@link UniNeighborhoodsCollectorAccumulator}, only for {@link BiNeighborhoodsCollector}. + * + * @param the type of the solution + * @param the type of the first fact in the source stream's tuple + * @param the type of the second fact in the source stream's tuple + * @param the mutable accumulation type + */ +@NullMarked +public interface BiNeighborhoodsCollectorAccumulator { + + BiNeighborhoodsCollectorValueHandle intoGroup(SolutionView view, ResultContainer_ container); + +} diff --git a/core/src/main/java/ai/timefold/solver/core/preview/api/neighborhood/stream/enumerating/collector/BiNeighborhoodsCollectorValueHandle.java b/core/src/main/java/ai/timefold/solver/core/preview/api/neighborhood/stream/enumerating/collector/BiNeighborhoodsCollectorValueHandle.java new file mode 100644 index 00000000000..d501fdb694f --- /dev/null +++ b/core/src/main/java/ai/timefold/solver/core/preview/api/neighborhood/stream/enumerating/collector/BiNeighborhoodsCollectorValueHandle.java @@ -0,0 +1,24 @@ +package ai.timefold.solver.core.preview.api.neighborhood.stream.enumerating.collector; + +import org.jspecify.annotations.NullMarked; +import org.jspecify.annotations.Nullable; + +/** + * As defined by {@link UniNeighborhoodsCollectorValueHandle}, only for {@link BiNeighborhoodsCollector}. + * + * @param the type of the first fact in the source stream's tuple + * @param the type of the second fact in the source stream's tuple + */ +@NullMarked +public interface BiNeighborhoodsCollectorValueHandle { + + void add(@Nullable A a, @Nullable B b); + + default void replaceWith(@Nullable A a, @Nullable B b) { + remove(); + add(a, b); + } + + void remove(); + +} diff --git a/core/src/main/java/ai/timefold/solver/core/preview/api/neighborhood/stream/enumerating/collector/NeighborhoodsCollectors.java b/core/src/main/java/ai/timefold/solver/core/preview/api/neighborhood/stream/enumerating/collector/NeighborhoodsCollectors.java new file mode 100644 index 00000000000..06f0aa9b4df --- /dev/null +++ b/core/src/main/java/ai/timefold/solver/core/preview/api/neighborhood/stream/enumerating/collector/NeighborhoodsCollectors.java @@ -0,0 +1,112 @@ +package ai.timefold.solver.core.preview.api.neighborhood.stream.enumerating.collector; + +import java.util.List; +import java.util.function.BiFunction; +import java.util.function.Function; + +import ai.timefold.solver.core.impl.neighborhood.stream.enumerating.collector.AndThenBiNeighborhoodsCollector; +import ai.timefold.solver.core.impl.neighborhood.stream.enumerating.collector.AndThenUniNeighborhoodsCollector; +import ai.timefold.solver.core.impl.neighborhood.stream.enumerating.collector.ComposeTwoBiNeighborhoodsCollector; +import ai.timefold.solver.core.impl.neighborhood.stream.enumerating.collector.ComposeTwoUniNeighborhoodsCollector; +import ai.timefold.solver.core.impl.neighborhood.stream.enumerating.collector.ToListBiNeighborhoodsCollector; +import ai.timefold.solver.core.impl.neighborhood.stream.enumerating.collector.ToListUniNeighborhoodsCollector; +import ai.timefold.solver.core.impl.util.ConstantLambdaUtils; +import ai.timefold.solver.core.preview.api.neighborhood.stream.enumerating.BiEnumeratingStream; +import ai.timefold.solver.core.preview.api.neighborhood.stream.enumerating.UniEnumeratingStream; +import ai.timefold.solver.core.preview.api.neighborhood.stream.function.BiNeighborhoodsMapper; +import ai.timefold.solver.core.preview.api.neighborhood.stream.function.UniNeighborhoodsMapper; + +import org.jspecify.annotations.NullMarked; + +/** + * Factory for {@link UniNeighborhoodsCollector} and {@link BiNeighborhoodsCollector} instances. + */ +@NullMarked +public final class NeighborhoodsCollectors { + + private NeighborhoodsCollectors() { + } + + /** + * As defined by {@link #toList(UniNeighborhoodsMapper)}, + * but using the fact directly without mapping. + */ + public static UniNeighborhoodsCollector> toList() { + return toList(ConstantLambdaUtils.neighborhoodsUniPickFirst()); + } + + /** + * Collects all facts from a {@link UniEnumeratingStream} group into a {@link List}, + * applying the given mapping function to each before collecting. + * + * @param mapper maps each tuple to a single value to collect + */ + public static UniNeighborhoodsCollector> + toList(UniNeighborhoodsMapper mapper) { + return ToListUniNeighborhoodsCollector.create(mapper); + } + + /** + * Collects all facts from a {@link BiEnumeratingStream} group into a {@link List}, + * applying the given mapping function to each {@code (A, B)} pair before collecting. + * The mapping function also has access to the working solution via {@link BiNeighborhoodsMapper}. + * + * @param mapper maps each tuple to a single value to collect + */ + public static BiNeighborhoodsCollector> toList( + BiNeighborhoodsMapper mapper) { + return ToListBiNeighborhoodsCollector.create(mapper); + } + + /** + * Collects results from a {@link UniNeighborhoodsCollector} and maps its result to another value. + *

+ * This is a better performing alternative to {@code .groupBy(...).map(...)}. + * + * @param the type of the solution + * @param generic type of the tuple variable + * @param generic type of the delegate's return value + * @param generic type of the final collector's return value + * @param delegate the underlying collector to delegate to + * @param mappingFunction maps the result of the underlying collector to another value + */ + public static + UniNeighborhoodsCollector + collectAndThen(UniNeighborhoodsCollector delegate, + Function mappingFunction) { + return new AndThenUniNeighborhoodsCollector<>(delegate, mappingFunction); + } + + /** + * As defined by {@link #collectAndThen(UniNeighborhoodsCollector, Function)}. + */ + public static + BiNeighborhoodsCollector + collectAndThen(BiNeighborhoodsCollector delegate, + Function mappingFunction) { + return new AndThenBiNeighborhoodsCollector<>(delegate, mappingFunction); + } + + /** + * Composes two {@link UniNeighborhoodsCollector}s into one, combining their results with the given function. + */ + public static + UniNeighborhoodsCollector compose( + UniNeighborhoodsCollector first, + UniNeighborhoodsCollector second, + BiFunction composeFunction) { + return new ComposeTwoUniNeighborhoodsCollector<>(first, second, composeFunction); + } + + /** + * Composes two {@link BiNeighborhoodsCollector}s into one, combining their results with the given function. + */ + public static + BiNeighborhoodsCollector compose( + BiNeighborhoodsCollector first, + BiNeighborhoodsCollector second, + BiFunction composeFunction) { + return new ComposeTwoBiNeighborhoodsCollector<>(first, second, composeFunction); + } + +} diff --git a/core/src/main/java/ai/timefold/solver/core/preview/api/neighborhood/stream/enumerating/collector/UniNeighborhoodsCollector.java b/core/src/main/java/ai/timefold/solver/core/preview/api/neighborhood/stream/enumerating/collector/UniNeighborhoodsCollector.java new file mode 100644 index 00000000000..399a07aacfb --- /dev/null +++ b/core/src/main/java/ai/timefold/solver/core/preview/api/neighborhood/stream/enumerating/collector/UniNeighborhoodsCollector.java @@ -0,0 +1,32 @@ +package ai.timefold.solver.core.preview.api.neighborhood.stream.enumerating.collector; + +import java.util.function.Function; +import java.util.function.Supplier; + +import ai.timefold.solver.core.preview.api.neighborhood.stream.enumerating.UniEnumeratingStream; + +import org.jspecify.annotations.NullMarked; +import org.jspecify.annotations.Nullable; + +/** + * Collects the facts from a {@link UniEnumeratingStream} group into a result. + * Used with {@link UniEnumeratingStream#groupBy}. + *

+ * Custom implementations should implement {@link Object#equals(Object)} and {@link Object#hashCode()} + * based on their fields to allow node sharing. + * + * @param the type of the solution + * @param the type of the only fact in the source stream's tuple + * @param the mutable accumulation type (often hidden as an implementation detail) + * @param the type of the result + */ +@NullMarked +public interface UniNeighborhoodsCollector { + + Supplier supplier(); + + UniNeighborhoodsCollectorAccumulator accumulator(); + + Function finisher(); + +} diff --git a/core/src/main/java/ai/timefold/solver/core/preview/api/neighborhood/stream/enumerating/collector/UniNeighborhoodsCollectorAccumulator.java b/core/src/main/java/ai/timefold/solver/core/preview/api/neighborhood/stream/enumerating/collector/UniNeighborhoodsCollectorAccumulator.java new file mode 100644 index 00000000000..b6a2e74e481 --- /dev/null +++ b/core/src/main/java/ai/timefold/solver/core/preview/api/neighborhood/stream/enumerating/collector/UniNeighborhoodsCollectorAccumulator.java @@ -0,0 +1,31 @@ +package ai.timefold.solver.core.preview.api.neighborhood.stream.enumerating.collector; + +import ai.timefold.solver.core.preview.api.move.SolutionView; + +import org.jspecify.annotations.NullMarked; + +/** + * Accumulates values into a group container for a {@link UniNeighborhoodsCollector}. + * Created once per group, then called for each value to obtain a {@link UniNeighborhoodsCollectorValueHandle} + * for inserting, updating, and removing values. + * + * @param the type of the solution + * @param the type of the only fact in the source stream's tuple + * @param the mutable accumulation type + */ +@NullMarked +public interface UniNeighborhoodsCollectorAccumulator { + + /** + * Called when a new value enters the group. + * The returned handle is used to insert the value ({@link UniNeighborhoodsCollectorValueHandle#add}), + * update it ({@link UniNeighborhoodsCollectorValueHandle#replaceWith}), + * and remove it ({@link UniNeighborhoodsCollectorValueHandle#remove}). + * + * @param view read-only access to the current working solution + * @param container the group's accumulation container + * @return a handle for the value in the group + */ + UniNeighborhoodsCollectorValueHandle intoGroup(SolutionView view, ResultContainer_ container); + +} diff --git a/core/src/main/java/ai/timefold/solver/core/preview/api/neighborhood/stream/enumerating/collector/UniNeighborhoodsCollectorValueHandle.java b/core/src/main/java/ai/timefold/solver/core/preview/api/neighborhood/stream/enumerating/collector/UniNeighborhoodsCollectorValueHandle.java new file mode 100644 index 00000000000..338fa8e0b16 --- /dev/null +++ b/core/src/main/java/ai/timefold/solver/core/preview/api/neighborhood/stream/enumerating/collector/UniNeighborhoodsCollectorValueHandle.java @@ -0,0 +1,25 @@ +package ai.timefold.solver.core.preview.api.neighborhood.stream.enumerating.collector; + +import org.jspecify.annotations.NullMarked; +import org.jspecify.annotations.Nullable; + +/** + * Represents a handle for a single value in a single {@link UniNeighborhoodsCollectorAccumulator} group. + * The handle is obtained from {@link UniNeighborhoodsCollectorAccumulator#intoGroup} when a new value enters the group, + * and is used to update or remove that value. + * + * @param the type of the only fact in the source stream's tuple + */ +@NullMarked +public interface UniNeighborhoodsCollectorValueHandle { + + void add(@Nullable A a); + + default void replaceWith(@Nullable A a) { + remove(); + add(a); + } + + void remove(); + +} diff --git a/core/src/main/java/module-info.java b/core/src/main/java/module-info.java index c262c15d4b0..4c6d9e1cc37 100644 --- a/core/src/main/java/module-info.java +++ b/core/src/main/java/module-info.java @@ -67,6 +67,7 @@ exports ai.timefold.solver.core.preview.api.move.test; exports ai.timefold.solver.core.preview.api.neighborhood; exports ai.timefold.solver.core.preview.api.neighborhood.stream; + exports ai.timefold.solver.core.preview.api.neighborhood.stream.enumerating.collector; exports ai.timefold.solver.core.preview.api.neighborhood.stream.enumerating; exports ai.timefold.solver.core.preview.api.neighborhood.stream.function; exports ai.timefold.solver.core.preview.api.neighborhood.stream.joiner; diff --git a/core/src/test/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/collector/NeighborhoodsCollectorsTest.java b/core/src/test/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/collector/NeighborhoodsCollectorsTest.java new file mode 100644 index 00000000000..a94877fedb0 --- /dev/null +++ b/core/src/test/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/collector/NeighborhoodsCollectorsTest.java @@ -0,0 +1,180 @@ +package ai.timefold.solver.core.impl.neighborhood.stream.enumerating.collector; + +import static ai.timefold.solver.core.preview.api.neighborhood.stream.enumerating.collector.NeighborhoodsCollectors.collectAndThen; +import static ai.timefold.solver.core.preview.api.neighborhood.stream.enumerating.collector.NeighborhoodsCollectors.compose; +import static ai.timefold.solver.core.preview.api.neighborhood.stream.enumerating.collector.NeighborhoodsCollectors.toList; +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.List; +import java.util.function.BiFunction; + +import ai.timefold.solver.core.impl.util.Pair; +import ai.timefold.solver.core.preview.api.neighborhood.stream.enumerating.collector.BiNeighborhoodsCollector; +import ai.timefold.solver.core.preview.api.neighborhood.stream.enumerating.collector.BiNeighborhoodsCollectorValueHandle; +import ai.timefold.solver.core.preview.api.neighborhood.stream.enumerating.collector.UniNeighborhoodsCollector; +import ai.timefold.solver.core.preview.api.neighborhood.stream.enumerating.collector.UniNeighborhoodsCollectorValueHandle; + +import org.jspecify.annotations.NullMarked; +import org.junit.jupiter.api.Test; + +@NullMarked +class NeighborhoodsCollectorsTest { + + // ************************************************************************ + // Helpers + // ************************************************************************ + + @SuppressWarnings("unchecked") + private static UniNeighborhoodsCollectorValueHandle accumulate( + UniNeighborhoodsCollector collector, Object container, A value) { + var slot = collector.accumulator() + .intoGroup(null, (C_) container); + slot.add(value); + return slot; + } + + @SuppressWarnings("unchecked") + private static void assertResult( + UniNeighborhoodsCollector collector, Object container, R_ expected) { + assertThat(collector.finisher().apply((C_) container)) + .as("Collector (" + collector + ") did not produce expected result.") + .isEqualTo(expected); + } + + @SuppressWarnings("unchecked") + private static BiNeighborhoodsCollectorValueHandle accumulate( + BiNeighborhoodsCollector collector, Object container, A a, B b) { + var slot = collector.accumulator() + .intoGroup(null, (C_) container); + slot.add(a, b); + return slot; + } + + @SuppressWarnings("unchecked") + private static void assertResult( + BiNeighborhoodsCollector collector, Object container, R_ expected) { + assertThat(collector.finisher().apply((C_) container)) + .as("Collector (" + collector + ") did not produce expected result.") + .isEqualTo(expected); + } + + // ************************************************************************ + // toList + // ************************************************************************ + + @Test + void uniToList() { + UniNeighborhoodsCollector> collector = toList(); + Object container = collector.supplier().get(); + + assertResult(collector, container, List.of()); + var h1 = accumulate(collector, container, 1); + assertResult(collector, container, List.of(1)); + var h2 = accumulate(collector, container, 2); + assertResult(collector, container, List.of(1, 2)); + h1.remove(); + assertResult(collector, container, List.of(2)); + h2.remove(); + assertResult(collector, container, List.of()); + } + + @Test + void biToList() { + BiNeighborhoodsCollector> collector = + toList((view, a, b) -> a); + Object container = collector.supplier().get(); + + assertResult(collector, container, List.of()); + var h1 = accumulate(collector, container, 1, "x"); + assertResult(collector, container, List.of(1)); + var h2 = accumulate(collector, container, 2, "y"); + assertResult(collector, container, List.of(1, 2)); + h1.remove(); + assertResult(collector, container, List.of(2)); + h2.remove(); + assertResult(collector, container, List.of()); + } + + // ************************************************************************ + // collectAndThen + // ************************************************************************ + + @Test + void uniCollectAndThen() { + UniNeighborhoodsCollector collector = + collectAndThen(toList(), List::size); + Object container = collector.supplier().get(); + + assertResult(collector, container, 0); + var h1 = accumulate(collector, container, 1); + assertResult(collector, container, 1); + var h2 = accumulate(collector, container, 2); + assertResult(collector, container, 2); + h1.remove(); + assertResult(collector, container, 1); + h2.remove(); + assertResult(collector, container, 0); + } + + @Test + void biCollectAndThen() { + BiNeighborhoodsCollector collector = + collectAndThen(toList((view, a, b) -> a), List::size); + Object container = collector.supplier().get(); + + assertResult(collector, container, 0); + var h1 = accumulate(collector, container, 1, "x"); + assertResult(collector, container, 1); + var h2 = accumulate(collector, container, 2, "y"); + assertResult(collector, container, 2); + h1.remove(); + assertResult(collector, container, 1); + h2.remove(); + assertResult(collector, container, 0); + } + + // ************************************************************************ + // compose (Uni) + // ************************************************************************ + + @Test + void uniCompose2() { + UniNeighborhoodsCollector, List>> collector = + compose(toList(), toList(), + (BiFunction, List, Pair, List>>) Pair::new); + Object container = collector.supplier().get(); + + assertResult(collector, container, new Pair<>(List.of(), List.of())); + var h1 = accumulate(collector, container, 1); + assertResult(collector, container, new Pair<>(List.of(1), List.of(1))); + var h2 = accumulate(collector, container, 2); + assertResult(collector, container, new Pair<>(List.of(1, 2), List.of(1, 2))); + h1.remove(); + assertResult(collector, container, new Pair<>(List.of(2), List.of(2))); + h2.remove(); + assertResult(collector, container, new Pair<>(List.of(), List.of())); + } + + // ************************************************************************ + // compose (Bi) + // ************************************************************************ + + @Test + void biCompose2() { + BiNeighborhoodsCollector, List>> collector = + compose(toList((view, a, b) -> a), toList((view, a, b) -> a), + (BiFunction, List, Pair, List>>) Pair::new); + Object container = collector.supplier().get(); + + assertResult(collector, container, new Pair<>(List.of(), List.of())); + var h1 = accumulate(collector, container, 1, "x"); + assertResult(collector, container, new Pair<>(List.of(1), List.of(1))); + var h2 = accumulate(collector, container, 2, "y"); + assertResult(collector, container, new Pair<>(List.of(1, 2), List.of(1, 2))); + h1.remove(); + assertResult(collector, container, new Pair<>(List.of(2), List.of(2))); + h2.remove(); + assertResult(collector, container, new Pair<>(List.of(), List.of())); + } + +} diff --git a/core/src/test/java/ai/timefold/solver/core/preview/api/neighborhood/stream/enumerating/BiEnumeratingStreamTest.java b/core/src/test/java/ai/timefold/solver/core/preview/api/neighborhood/stream/enumerating/BiEnumeratingStreamTest.java new file mode 100644 index 00000000000..ae00fe7d963 --- /dev/null +++ b/core/src/test/java/ai/timefold/solver/core/preview/api/neighborhood/stream/enumerating/BiEnumeratingStreamTest.java @@ -0,0 +1,113 @@ +package ai.timefold.solver.core.preview.api.neighborhood.stream.enumerating; + +import static org.assertj.core.api.Assertions.assertThat; + +import ai.timefold.solver.core.api.score.SimpleScore; +import ai.timefold.solver.core.config.solver.EnvironmentMode; +import ai.timefold.solver.core.impl.bavet.common.tuple.UniTuple; +import ai.timefold.solver.core.impl.neighborhood.stream.enumerating.DatasetSession; +import ai.timefold.solver.core.impl.neighborhood.stream.enumerating.DatasetSessionFactory; +import ai.timefold.solver.core.impl.neighborhood.stream.enumerating.EnumeratingStreamFactory; +import ai.timefold.solver.core.impl.neighborhood.stream.enumerating.uni.AbstractUniEnumeratingStream; +import ai.timefold.solver.core.impl.neighborhood.stream.enumerating.uni.UniLeftDataset; +import ai.timefold.solver.core.impl.neighborhood.stream.enumerating.uni.UniLeftDatasetInstance; +import ai.timefold.solver.core.impl.score.director.SessionContext; +import ai.timefold.solver.core.impl.score.director.easy.EasyScoreDirectorFactory; +import ai.timefold.solver.core.preview.api.neighborhood.stream.enumerating.collector.NeighborhoodsCollectors; +import ai.timefold.solver.core.preview.api.neighborhood.stream.function.BiNeighborhoodsMapper; +import ai.timefold.solver.core.preview.api.neighborhood.stream.joiner.NeighborhoodsJoiners; +import ai.timefold.solver.core.testdomain.TestdataEntity; +import ai.timefold.solver.core.testdomain.TestdataSolution; +import ai.timefold.solver.core.testdomain.TestdataValue; + +import org.jspecify.annotations.NullMarked; +import org.junit.jupiter.api.Test; + +@NullMarked +class BiEnumeratingStreamTest { + + // ************************************************************************ + // Helpers + // ************************************************************************ + + private static EnumeratingStreamFactory factory() { + return new EnumeratingStreamFactory<>(TestdataSolution.buildSolutionDescriptor(), EnvironmentMode.PHASE_ASSERT); + } + + private static AbstractUniEnumeratingStream entityStream( + EnumeratingStreamFactory factory) { + return (AbstractUniEnumeratingStream) factory + .forEachNonDiscriminating(TestdataEntity.class, false); + } + + private static UniLeftDatasetInstance getInstance(DatasetSession session, + UniLeftDataset dataset) { + return (UniLeftDatasetInstance) session.getInstance(dataset); + } + + private static DatasetSession createSession( + EnumeratingStreamFactory enumeratingStreamFactory, + TestdataSolution solution) { + var scoreDirector = + new EasyScoreDirectorFactory<>(enumeratingStreamFactory.getSolutionDescriptor(), s -> SimpleScore.ZERO, + EnvironmentMode.PHASE_ASSERT) + .buildScoreDirector(); + scoreDirector.setWorkingSolution(solution); + var sessionContext = new SessionContext<>(scoreDirector); + var datasetSessionFactory = new DatasetSessionFactory<>(enumeratingStreamFactory); + var datasetSession = datasetSessionFactory.buildSession(sessionContext); + enumeratingStreamFactory.getSolutionDescriptor().visitAll(solution, datasetSession::insert); + datasetSession.settle(); + return datasetSession; + } + + // ************************************************************************ + // groupBy + // ************************************************************************ + + @Test + void groupBy_1Mapping0Collector() { + var factory = factory(); + var entityStream = entityStream(factory); + var valueStream = (AbstractUniEnumeratingStream) factory + .forEachNonDiscriminating(TestdataValue.class, false); + var biStream = entityStream.join(valueStream, + NeighborhoodsJoiners.equal(TestdataEntity::getValue, v -> v)); + BiNeighborhoodsMapper byValueCode = + (view, entity, value) -> value.getCode(); + var groupedStream = (AbstractUniEnumeratingStream) biStream.groupBy(byValueCode); + var dataset = groupedStream.createLeftDataset(); + + var solution = TestdataSolution.generateSolution(2, 4); + var session = createSession(factory, solution); + var instance = getInstance(session, dataset); + + assertThat(instance.iterator()).toIterable().map(UniTuple::getA) + .containsExactlyInAnyOrder("Generated Value 0", "Generated Value 1"); + } + + @Test + void groupBy_1Mapping1Collector() { + var factory = factory(); + var entityStream = entityStream(factory); + var valueStream = (AbstractUniEnumeratingStream) factory + .forEachNonDiscriminating(TestdataValue.class, false); + var biStream = entityStream.join(valueStream, + NeighborhoodsJoiners.equal(TestdataEntity::getValue, v -> v)); + BiNeighborhoodsMapper byValue = + (view, entity, value) -> value; + var groupedStream = biStream.groupBy(byValue, + NeighborhoodsCollectors.toList((view, entity, value) -> entity.getCode())); + var mappedStream = (AbstractUniEnumeratingStream) groupedStream + .map((view, value, entityCodes) -> value.getCode() + "=" + entityCodes.size()); + var dataset = mappedStream.createLeftDataset(); + + var solution = TestdataSolution.generateSolution(2, 4); + var session = createSession(factory, solution); + var instance = getInstance(session, dataset); + + assertThat(instance.iterator()).toIterable().map(UniTuple::getA) + .containsExactlyInAnyOrder("Generated Value 0=2", "Generated Value 1=2"); + } + +} diff --git a/core/src/test/java/ai/timefold/solver/core/preview/api/neighborhood/stream/enumerating/UniEnumeratingStreamTest.java b/core/src/test/java/ai/timefold/solver/core/preview/api/neighborhood/stream/enumerating/UniEnumeratingStreamTest.java index bbf45126191..935288305ef 100644 --- a/core/src/test/java/ai/timefold/solver/core/preview/api/neighborhood/stream/enumerating/UniEnumeratingStreamTest.java +++ b/core/src/test/java/ai/timefold/solver/core/preview/api/neighborhood/stream/enumerating/UniEnumeratingStreamTest.java @@ -3,6 +3,9 @@ import static org.assertj.core.api.Assertions.assertThat; import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Function; +import java.util.function.Supplier; import ai.timefold.solver.core.api.score.SimpleScore; import ai.timefold.solver.core.api.solver.SolutionManager; @@ -16,20 +19,71 @@ import ai.timefold.solver.core.impl.neighborhood.stream.enumerating.uni.UniLeftDatasetInstance; import ai.timefold.solver.core.impl.score.director.SessionContext; import ai.timefold.solver.core.impl.score.director.easy.EasyScoreDirectorFactory; +import ai.timefold.solver.core.preview.api.neighborhood.stream.enumerating.collector.NeighborhoodsCollectors; +import ai.timefold.solver.core.preview.api.neighborhood.stream.enumerating.collector.UniNeighborhoodsCollector; +import ai.timefold.solver.core.preview.api.neighborhood.stream.enumerating.collector.UniNeighborhoodsCollectorAccumulator; +import ai.timefold.solver.core.preview.api.neighborhood.stream.enumerating.collector.UniNeighborhoodsCollectorValueHandle; +import ai.timefold.solver.core.preview.api.neighborhood.stream.function.UniNeighborhoodsMapper; import ai.timefold.solver.core.testdomain.TestdataEntity; import ai.timefold.solver.core.testdomain.TestdataSolution; +import ai.timefold.solver.core.testdomain.TestdataValue; import ai.timefold.solver.core.testdomain.list.TestdataListEntity; import ai.timefold.solver.core.testdomain.list.TestdataListSolution; import ai.timefold.solver.core.testdomain.list.pinned.index.TestdataPinnedWithIndexListEntity; import ai.timefold.solver.core.testdomain.list.pinned.index.TestdataPinnedWithIndexListSolution; import ai.timefold.solver.core.testdomain.list.pinned.index.TestdataPinnedWithIndexListValue; +import org.jspecify.annotations.NullMarked; +import org.jspecify.annotations.Nullable; import org.junit.jupiter.api.Test; +@NullMarked class UniEnumeratingStreamTest { + // ************************************************************************ + // Helpers + // ************************************************************************ + + private static EnumeratingStreamFactory factory() { + return new EnumeratingStreamFactory<>(TestdataSolution.buildSolutionDescriptor(), EnvironmentMode.PHASE_ASSERT); + } + + private static AbstractUniEnumeratingStream entityStream( + EnumeratingStreamFactory factory) { + return (AbstractUniEnumeratingStream) factory + .forEachNonDiscriminating(TestdataEntity.class, false); + } + + private static UniLeftDatasetInstance getInstance(DatasetSession session, + UniLeftDataset dataset) { + return (UniLeftDatasetInstance) session.getInstance(dataset); + } + + private static DatasetSession createSession( + EnumeratingStreamFactory enumeratingStreamFactory, + Solution_ solution) { + var scoreDirector = + new EasyScoreDirectorFactory<>(enumeratingStreamFactory.getSolutionDescriptor(), s -> SimpleScore.ZERO, + EnvironmentMode.PHASE_ASSERT) + .buildScoreDirector(); + scoreDirector.setWorkingSolution(solution); + var sessionContext = new SessionContext<>(scoreDirector); + var datasetSessionFactory = new DatasetSessionFactory<>(enumeratingStreamFactory); + var datasetSession = datasetSessionFactory.buildSession(sessionContext); + var solutionDescriptor = enumeratingStreamFactory.getSolutionDescriptor(); + + solutionDescriptor.visitAll(solution, datasetSession::insert); + + datasetSession.settle(); + return datasetSession; + } + + // ************************************************************************ + // forEach + // ************************************************************************ + @Test - void forEachBasicVariable() { + void forEach_basicVar() { var enumeratingStreamFactory = new EnumeratingStreamFactory<>(TestdataSolution.buildSolutionDescriptor(), EnvironmentMode.PHASE_ASSERT); var uniDataset = ((AbstractUniEnumeratingStream) enumeratingStreamFactory @@ -37,8 +91,8 @@ void forEachBasicVariable() { .createLeftDataset(); var solution = TestdataSolution.generateSolution(2, 2); - var datasetSession = UniEnumeratingStreamTest.createSession(enumeratingStreamFactory, solution); - var uniDatasetInstance = getDatasetInstance(datasetSession, uniDataset); + var datasetSession = createSession(enumeratingStreamFactory, solution); + var uniDatasetInstance = getInstance(datasetSession, uniDataset); var entity1 = solution.getEntityList().get(0); var entity2 = solution.getEntityList().get(1); @@ -49,7 +103,7 @@ void forEachBasicVariable() { .containsExactly(entity1, entity2); // Make incremental changes. - var entity3 = new TestdataEntity("entity3", solution.getValueList().get(0)); + var entity3 = new TestdataEntity("entity3", solution.getValueList().getFirst()); datasetSession.insert(entity3); datasetSession.retract(entity2); datasetSession.settle(); @@ -60,13 +114,8 @@ void forEachBasicVariable() { .containsExactly(entity1, entity3); } - private static UniLeftDatasetInstance getDatasetInstance(DatasetSession session, - UniLeftDataset dataset) { - return (UniLeftDatasetInstance) session.getInstance(dataset); - } - @Test - void forEachBasicVariableIncludingNull() { + void forEachIncludingNull_basicVar() { var enumeratingStreamFactory = new EnumeratingStreamFactory<>(TestdataSolution.buildSolutionDescriptor(), EnvironmentMode.PHASE_ASSERT); var uniDataset = ((AbstractUniEnumeratingStream) enumeratingStreamFactory @@ -74,8 +123,8 @@ void forEachBasicVariableIncludingNull() { .createLeftDataset(); var solution = TestdataSolution.generateSolution(2, 2); - var datasetSession = UniEnumeratingStreamTest.createSession(enumeratingStreamFactory, solution); - var uniDatasetInstance = getDatasetInstance(datasetSession, uniDataset); + var datasetSession = createSession(enumeratingStreamFactory, solution); + var uniDatasetInstance = getInstance(datasetSession, uniDataset); var entity1 = solution.getEntityList().get(0); var entity2 = solution.getEntityList().get(1); @@ -86,7 +135,7 @@ void forEachBasicVariableIncludingNull() { .containsExactly(null, entity1, entity2); // Make incremental changes. - var entity3 = new TestdataEntity("entity3", solution.getValueList().get(0)); + var entity3 = new TestdataEntity("entity3", solution.getValueList().getFirst()); datasetSession.insert(entity3); datasetSession.retract(entity2); datasetSession.settle(); @@ -98,7 +147,7 @@ void forEachBasicVariableIncludingNull() { } @Test - void forEachListVariable() { + void forEach_listVar() { var enumeratingStreamFactory = new EnumeratingStreamFactory<>(TestdataListSolution.buildSolutionDescriptor(), EnvironmentMode.PHASE_ASSERT); var uniDataset = ((AbstractUniEnumeratingStream) enumeratingStreamFactory @@ -107,7 +156,7 @@ void forEachListVariable() { var solution = TestdataListSolution.generateInitializedSolution(2, 2); var datasetSession = createSession(enumeratingStreamFactory, solution); - var uniDatasetInstance = getDatasetInstance(datasetSession, uniDataset); + var uniDatasetInstance = getInstance(datasetSession, uniDataset); var entity1 = solution.getEntityList().get(0); var entity2 = solution.getEntityList().get(1); @@ -130,7 +179,7 @@ void forEachListVariable() { } @Test - void forEachListVariableIncludingNull() { + void forEachIncludingNull_listVar() { var enumeratingStreamFactory = new EnumeratingStreamFactory<>(TestdataListSolution.buildSolutionDescriptor(), EnvironmentMode.PHASE_ASSERT); var uniDataset = ((AbstractUniEnumeratingStream) enumeratingStreamFactory @@ -139,7 +188,7 @@ void forEachListVariableIncludingNull() { var solution = TestdataListSolution.generateInitializedSolution(2, 2); var datasetSession = createSession(enumeratingStreamFactory, solution); - var uniDatasetInstance = getDatasetInstance(datasetSession, uniDataset); + var uniDatasetInstance = getInstance(datasetSession, uniDataset); var entity1 = solution.getEntityList().get(0); var entity2 = solution.getEntityList().get(1); @@ -161,27 +210,8 @@ void forEachListVariableIncludingNull() { .containsExactly(null, entity1, entity3); } - private static DatasetSession createSession( - EnumeratingStreamFactory enumeratingStreamFactory, - Solution_ solution) { - var scoreDirector = - new EasyScoreDirectorFactory<>(enumeratingStreamFactory.getSolutionDescriptor(), s -> SimpleScore.ZERO, - EnvironmentMode.PHASE_ASSERT) - .buildScoreDirector(); - scoreDirector.setWorkingSolution(solution); - var sessionContext = new SessionContext<>(scoreDirector); - var datasetSessionFactory = new DatasetSessionFactory<>(enumeratingStreamFactory); - var datasetSession = datasetSessionFactory.buildSession(sessionContext); - var solutionDescriptor = enumeratingStreamFactory.getSolutionDescriptor(); - - solutionDescriptor.visitAll(solution, datasetSession::insert); - - datasetSession.settle(); - return datasetSession; - } - @Test - void forEachListVariableIncludingPinned() { + void forEach_listVarIncludingPinned() { var enumeratingStreamFactory = new EnumeratingStreamFactory<>(TestdataPinnedWithIndexListSolution.buildSolutionDescriptor(), EnvironmentMode.PHASE_ASSERT); @@ -204,7 +234,7 @@ void forEachListVariableIncludingPinned() { unpinnedEntity.setPinIndex(0); var datasetSession = createSession(enumeratingStreamFactory, solution); - var uniDatasetInstance = getDatasetInstance(datasetSession, uniDataset); + var uniDatasetInstance = getInstance(datasetSession, uniDataset); assertThat(uniDatasetInstance.iterator()) .toIterable() @@ -225,7 +255,7 @@ void forEachListVariableIncludingPinned() { } @Test - void forEachListVariableIncludingPinnedAndNull() { + void forEachIncludingNull_listVarIncludingPinned() { var enumeratingStreamFactory = new EnumeratingStreamFactory<>(TestdataPinnedWithIndexListSolution.buildSolutionDescriptor(), EnvironmentMode.PHASE_ASSERT); @@ -248,7 +278,7 @@ void forEachListVariableIncludingPinnedAndNull() { unpinnedEntity.setPinIndex(0); var datasetSession = createSession(enumeratingStreamFactory, solution); - var uniDatasetInstance = getDatasetInstance(datasetSession, uniDataset); + var uniDatasetInstance = getInstance(datasetSession, uniDataset); assertThat(uniDatasetInstance.iterator()) .toIterable() @@ -269,7 +299,7 @@ void forEachListVariableIncludingPinnedAndNull() { } @Test - void forEachListVariableExcludingPinned() { // Entities with planningPin true will be skipped. + void forEachExcludingPinned_listVar() { // Entities with planningPin true will be skipped. var enumeratingStreamFactory = new EnumeratingStreamFactory<>(TestdataPinnedWithIndexListSolution.buildSolutionDescriptor(), EnvironmentMode.PHASE_ASSERT); @@ -293,7 +323,7 @@ void forEachListVariableExcludingPinned() { // Entities with planningPin true wi unpinnedEntity.setPinIndex(0); var datasetSession = createSession(enumeratingStreamFactory, solution); - var uniDatasetInstance = getDatasetInstance(datasetSession, uniDataset); + var uniDatasetInstance = getInstance(datasetSession, uniDataset); assertThat(uniDatasetInstance.iterator()) .toIterable() @@ -314,7 +344,7 @@ void forEachListVariableExcludingPinned() { // Entities with planningPin true wi } @Test - void forEachListVariableExcludingPinnedIncludingNull() { // Entities with planningPin true will be skipped. + void forEachExcludingPinnedIncludingNull_listVar() { // Entities with planningPin true will be skipped. var enumeratingStreamFactory = new EnumeratingStreamFactory<>(TestdataPinnedWithIndexListSolution.buildSolutionDescriptor(), EnvironmentMode.PHASE_ASSERT); @@ -338,7 +368,7 @@ void forEachListVariableExcludingPinnedIncludingNull() { // Entities with planni unpinnedEntity.setPinIndex(0); var datasetSession = createSession(enumeratingStreamFactory, solution); - var uniDatasetInstance = getDatasetInstance(datasetSession, uniDataset); + var uniDatasetInstance = getInstance(datasetSession, uniDataset); assertThat(uniDatasetInstance.iterator()) .toIterable() @@ -359,7 +389,7 @@ void forEachListVariableExcludingPinnedIncludingNull() { // Entities with planni } @Test - void forEachListVariableIncludingPinnedValues() { + void forEach_listVarIncludingPinnedValues() { var enumeratingStreamFactory = new EnumeratingStreamFactory<>(TestdataPinnedWithIndexListSolution.buildSolutionDescriptor(), EnvironmentMode.PHASE_ASSERT); @@ -376,7 +406,7 @@ void forEachListVariableIncludingPinnedValues() { var value4 = solution.getValueList().get(3); var unassignedValue = solution.getValueList().get(4); // 1 value, entity pinned. - var fullyPinnedEntity = solution.getEntityList().get(0); + var fullyPinnedEntity = solution.getEntityList().getFirst(); fullyPinnedEntity.setPinned(true); fullyPinnedEntity.setValueList(List.of(value1)); // 2 values, 1 pinned. @@ -392,7 +422,7 @@ void forEachListVariableIncludingPinnedValues() { SolutionManager.updateShadowVariables(solution); var datasetSession = createSession(enumeratingStreamFactory, solution); - var uniDatasetInstance = getDatasetInstance(datasetSession, uniDataset); + var uniDatasetInstance = getInstance(datasetSession, uniDataset); assertThat(uniDatasetInstance.iterator()) .toIterable() @@ -401,7 +431,7 @@ void forEachListVariableIncludingPinnedValues() { } @Test - void forEachListVariableIncludingPinnedValuesAndNull() { + void forEachIncludingNull_listVarIncludingPinnedValues() { var enumeratingStreamFactory = new EnumeratingStreamFactory<>(TestdataPinnedWithIndexListSolution.buildSolutionDescriptor(), EnvironmentMode.PHASE_ASSERT); @@ -418,7 +448,7 @@ void forEachListVariableIncludingPinnedValuesAndNull() { var value4 = solution.getValueList().get(3); var unassignedValue = solution.getValueList().get(4); // 1 value, entity pinned. - var fullyPinnedEntity = solution.getEntityList().get(0); + var fullyPinnedEntity = solution.getEntityList().getFirst(); fullyPinnedEntity.setPinned(true); fullyPinnedEntity.setValueList(List.of(value1)); // 2 values, 1 pinned. @@ -434,7 +464,7 @@ void forEachListVariableIncludingPinnedValuesAndNull() { SolutionManager.updateShadowVariables(solution); var datasetSession = createSession(enumeratingStreamFactory, solution); - var uniDatasetInstance = getDatasetInstance(datasetSession, uniDataset); + var uniDatasetInstance = getInstance(datasetSession, uniDataset); assertThat(uniDatasetInstance.iterator()) .toIterable() @@ -443,7 +473,7 @@ void forEachListVariableIncludingPinnedValuesAndNull() { } @Test - void forEachListVariableExcludingPinnedValues() { + void forEachExcludingPinned_listVarValues() { var solutionDescriptor = TestdataPinnedWithIndexListSolution.buildSolutionDescriptor(); var enumeratingStreamFactory = new EnumeratingStreamFactory<>(solutionDescriptor, EnvironmentMode.PHASE_ASSERT); var uniDataset = @@ -459,7 +489,7 @@ void forEachListVariableExcludingPinnedValues() { var value3 = solution.getValueList().get(3); var value4 = solution.getValueList().get(4); // Initially unassigned. // 1 value, entity pinned. - var fullyPinnedEntity = solution.getEntityList().get(0); + var fullyPinnedEntity = solution.getEntityList().getFirst(); fullyPinnedEntity.setPinned(true); fullyPinnedEntity.setValueList(List.of(value0)); // 2 values, 1 pinned. @@ -478,7 +508,7 @@ void forEachListVariableExcludingPinnedValues() { SolutionManager.updateShadowVariables(solution); var datasetSession = createSession(enumeratingStreamFactory, solution); - var uniDatasetInstance = getDatasetInstance(datasetSession, uniDataset); + var uniDatasetInstance = getInstance(datasetSession, uniDataset); assertThat(uniDatasetInstance.iterator()) .toIterable() @@ -487,7 +517,7 @@ void forEachListVariableExcludingPinnedValues() { } @Test - void forEachListVariableExcludingPinnedValuesIncludingNull() { + void forEachExcludingPinnedIncludingNull_listVarValues() { var solutionDescriptor = TestdataPinnedWithIndexListSolution.buildSolutionDescriptor(); var enumeratingStreamFactory = new EnumeratingStreamFactory<>(solutionDescriptor, EnvironmentMode.PHASE_ASSERT); var uniDataset = @@ -503,7 +533,7 @@ void forEachListVariableExcludingPinnedValuesIncludingNull() { var value3 = solution.getValueList().get(3); var value4 = solution.getValueList().get(4); // Initially unassigned. // 1 value, entity pinned. - var fullyPinnedEntity = solution.getEntityList().get(0); + var fullyPinnedEntity = solution.getEntityList().getFirst(); fullyPinnedEntity.setPinned(true); fullyPinnedEntity.setValueList(List.of(value0)); // 2 values, 1 pinned. @@ -522,7 +552,7 @@ void forEachListVariableExcludingPinnedValuesIncludingNull() { SolutionManager.updateShadowVariables(solution); var datasetSession = createSession(enumeratingStreamFactory, solution); - var uniDatasetInstance = getDatasetInstance(datasetSession, uniDataset); + var uniDatasetInstance = getInstance(datasetSession, uniDataset); assertThat(uniDatasetInstance.iterator()) .toIterable() @@ -530,4 +560,198 @@ void forEachListVariableExcludingPinnedValuesIncludingNull() { .containsExactly(null, value2, value3, value4); } + // ************************************************************************ + // groupBy + // ************************************************************************ + + @Test + void groupBy_1Mapping0Collector() { + var factory = factory(); + UniNeighborhoodsMapper byValue = + (view, entity) -> entity.getValue(); + var groupedStream = entityStream(factory).groupBy(byValue); + var dataset = groupedStream.createLeftDataset(); + + // generateSolution(2 values, 4 entities): e0→v0, e1→v1, e2→v0, e3→v1 + var solution = TestdataSolution.generateSolution(2, 4); + var session = createSession(factory, solution); + var instance = getInstance(session, dataset); + + var v0 = solution.getValueList().get(0); + var v1 = solution.getValueList().get(1); + + assertThat(instance.iterator()).toIterable().map(UniTuple::getA).containsExactlyInAnyOrder(v0, v1); + + // Reassign e2 from v0 to v1; v0 still has e0. + var e2 = solution.getEntityList().get(2); + e2.setValue(v1); + session.update(e2); + session.settle(); + + assertThat(instance.iterator()).toIterable().map(UniTuple::getA).containsExactlyInAnyOrder(v0, v1); + + // Reassign e0 from v0 to v1 → v0 group disappears. + var e0 = solution.getEntityList().getFirst(); + e0.setValue(v1); + session.update(e0); + session.settle(); + + assertThat(instance.iterator()).toIterable().map(UniTuple::getA).containsExactly(v1); + } + + @Test + void groupBy_1Mapping1Collector() { + var factory = factory(); + UniNeighborhoodsMapper byValue = + (view, entity) -> entity.getValue(); + var groupedStream = entityStream(factory).groupBy(byValue, NeighborhoodsCollectors.toList()); + var mappedStream = (AbstractUniEnumeratingStream) groupedStream + .map((view, value, entities) -> value.getCode() + "=" + entities.size()); + var dataset = mappedStream.createLeftDataset(); + + var solution = TestdataSolution.generateSolution(2, 4); + var session = createSession(factory, solution); + var instance = getInstance(session, dataset); + + assertThat(instance.iterator()).toIterable().map(UniTuple::getA) + .containsExactlyInAnyOrder("Generated Value 0=2", "Generated Value 1=2"); + + var v0 = solution.getValueList().getFirst(); + var newEntity = new TestdataEntity("New Entity", v0); + solution.getEntityList().add(newEntity); + session.insert(newEntity); + session.settle(); + + assertThat(instance.iterator()).toIterable().map(UniTuple::getA) + .containsExactlyInAnyOrder("Generated Value 0=3", "Generated Value 1=2"); + } + + @Test + void groupBy_0Mapping1Collector() { + var factory = factory(); + var groupedStream = entityStream(factory) + .groupBy(NeighborhoodsCollectors. toList()); + var dataset = groupedStream.createLeftDataset(); + + var solution = TestdataSolution.generateSolution(2, 3); + var session = createSession(factory, solution); + var instance = getInstance(session, dataset); + + assertThat(instance.iterator()).toIterable() + .map(UniTuple::getA) + .hasSize(1); + + var v0 = solution.getValueList().getFirst(); + var newEntity = new TestdataEntity("New Entity", v0); + solution.getEntityList().add(newEntity); + session.insert(newEntity); + session.settle(); + + assertThat(instance.iterator()).toIterable() + .map(UniTuple::getA) + .hasSize(1); + } + + @Test + void distinct() { + var factory = factory(); + var mappedStream = (AbstractUniEnumeratingStream) entityStream(factory) + .map((view, entity) -> entity.getValue()) + .distinct(); + var dataset = mappedStream.createLeftDataset(); + + var solution = TestdataSolution.generateSolution(2, 4); + var session = createSession(factory, solution); + var instance = getInstance(session, dataset); + var v0 = solution.getValueList().get(0); + var v1 = solution.getValueList().get(1); + + assertThat(instance.iterator()).toIterable().map(UniTuple::getA).containsExactlyInAnyOrder(v0, v1); + } + + @Test + void groupBy_1Mapping1Collector_customCollector() { + var factory = factory(); + UniNeighborhoodsMapper byValue = + (view, entity) -> entity.getValue(); + UniNeighborhoodsCollector countCollector = + new NeighborhoodsUniCountCollector(); + + var groupedStream = entityStream(factory).groupBy(byValue, countCollector); + var mappedStream = (AbstractUniEnumeratingStream) groupedStream + .map((view, value, count) -> value.getCode() + "=" + count); + var dataset = mappedStream.createLeftDataset(); + + var solution = TestdataSolution.generateSolution(2, 4); + var session = createSession(factory, solution); + var instance = getInstance(session, dataset); + + assertThat(instance.iterator()).toIterable().map(UniTuple::getA) + .containsExactlyInAnyOrder("Generated Value 0=2", "Generated Value 1=2"); + + var e0 = solution.getEntityList().getFirst(); + session.retract(e0); + session.settle(); + + assertThat(instance.iterator()).toIterable().map(UniTuple::getA) + .containsExactlyInAnyOrder("Generated Value 0=1", "Generated Value 1=2"); + } + + @Test + void groupBy_1Mapping0Collector_solutionView() { + var factory = factory(); + var accessCount = new AtomicInteger(); + UniNeighborhoodsMapper mapper = (view, entity) -> { + accessCount.incrementAndGet(); + return entity.getValue(); + }; + var groupedStream = entityStream(factory).groupBy(mapper); + var dataset = groupedStream.createLeftDataset(); + + var solution = TestdataSolution.generateSolution(2, 4); + var session = createSession(factory, solution); + var instance = getInstance(session, dataset); + var v0 = solution.getValueList().get(0); + var v1 = solution.getValueList().get(1); + + assertThat(instance.iterator()).toIterable().map(UniTuple::getA).containsExactlyInAnyOrder(v0, v1); + assertThat(accessCount.get()).isPositive(); + + var e2 = solution.getEntityList().get(2); + var oldCount = accessCount.get(); + e2.setValue(v1); + session.update(e2); + session.settle(); + + assertThat(instance.iterator()).toIterable().map(UniTuple::getA).containsExactlyInAnyOrder(v0, v1); + assertThat(accessCount.get()).isGreaterThan(oldCount); + } + + private static final class NeighborhoodsUniCountCollector + implements UniNeighborhoodsCollector { + @Override + public Supplier supplier() { + return () -> new int[] { 0 }; + } + + @Override + public UniNeighborhoodsCollectorAccumulator accumulator() { + return (view, container) -> new UniNeighborhoodsCollectorValueHandle() { + @Override + public void add(@Nullable TestdataEntity entity) { + container[0]++; + } + + @Override + public void remove() { + container[0]--; + } + }; + } + + @Override + public Function finisher() { + return c -> c[0]; + } + } } diff --git a/docs/src/modules/ROOT/pages/optimization-algorithms/neighborhoods.adoc b/docs/src/modules/ROOT/pages/optimization-algorithms/neighborhoods.adoc index 63781c4eb93..bf441770a28 100644 --- a/docs/src/modules/ROOT/pages/optimization-algorithms/neighborhoods.adoc +++ b/docs/src/modules/ROOT/pages/optimization-algorithms/neighborhoods.adoc @@ -532,7 +532,7 @@ the `MONDAY_MORNING` timeslot is a valid value for the `timeslotVariable` of the If that's not the case, we filter out this lesson from the enumeration and therefore will not generate moves which use this lesson. -The same pattern applies to other building blocks as well, such as `join` and `ifExists`; +The same pattern applies to other building blocks as well, such as `join`, `ifExists` and `groupBy`; essentially, the `solutionView` argument was added to every predicate or function where it could be useful. diff --git a/model/test-model/src/main/java/ai/timefold/solver/model/testmodel/EmployeeScheduleInputMetrics.java b/model/test-model/src/main/java/ai/timefold/solver/model/testmodel/EmployeeScheduleInputMetrics.java index 2447c4b86f4..02ef6e49267 100644 --- a/model/test-model/src/main/java/ai/timefold/solver/model/testmodel/EmployeeScheduleInputMetrics.java +++ b/model/test-model/src/main/java/ai/timefold/solver/model/testmodel/EmployeeScheduleInputMetrics.java @@ -21,7 +21,7 @@ public record EmployeeScheduleInputMetrics( @Extension(name = "x-tf-priority", value = "2"), @Extension(name = "x-tf-example", value = "10") }) int employeeCount) implements - ModelInputMetrics{ + ModelInputMetrics { public static final String INPUT_METRIC_EMPLOYEES = "employeeCount"; public static final String INPUT_METRIC_SHIFTS = "shiftCount"; diff --git a/model/test-model/src/main/java/ai/timefold/solver/model/testmodel/EmployeeScheduleOutputMetrics.java b/model/test-model/src/main/java/ai/timefold/solver/model/testmodel/EmployeeScheduleOutputMetrics.java index f939a981e4f..aa512077b5a 100644 --- a/model/test-model/src/main/java/ai/timefold/solver/model/testmodel/EmployeeScheduleOutputMetrics.java +++ b/model/test-model/src/main/java/ai/timefold/solver/model/testmodel/EmployeeScheduleOutputMetrics.java @@ -16,7 +16,7 @@ public record EmployeeScheduleOutputMetrics( @Extension(name = "x-tf-priority", value = "1"), @Extension(name = "x-tf-example", value = "100") }) int assignedShifts) implements - ModelOutputMetrics{ + ModelOutputMetrics { public static final String OUTPUT_METRIC_SHIFTS = "assignedShifts"; }