Skip to content

Commit 338a410

Browse files
committed
copilot
1 parent 20e1460 commit 338a410

5 files changed

Lines changed: 44 additions & 34 deletions

File tree

core/src/main/java/ai/timefold/solver/core/impl/bavet/AbstractBavetNodeNetwork.java

Lines changed: 17 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,7 @@ public abstract class AbstractBavetNodeNetwork {
3030

3131
protected static AbstractNode[][] buildLayeredNodes(List<AbstractNode> nodeList) {
3232
var layerMap = new TreeMap<Long, List<AbstractNode>>();
33-
nodeList.forEach(node -> layerMap.computeIfAbsent(node.getLayerIndex(), unused -> new ArrayList<>())
34-
.add(node));
33+
nodeList.forEach(node -> layerMap.computeIfAbsent(node.getLayerIndex(), unused -> new ArrayList<>()).add(node));
3534
var layerCount = layerMap.size();
3635
var layeredNodes = new AbstractNode[layerCount][];
3736
for (var i = 0; i < layerCount; i++) {
@@ -73,13 +72,21 @@ public int forEachNodeCount() {
7372
return declaredClassToNodeMap.size();
7473
}
7574

75+
/**
76+
*
77+
* @param factClass
78+
* @return if {@link #isActivationCheckComplete()} is true, only returns active root nodes;
79+
* otherwise returns all root nodes.
80+
* This means that if this information was ever read before activation checks were complete,
81+
* it should be re-read after to make sure no inactive nodes are included.
82+
*/
7683
public Stream<AbstractRootNode<?>> getRootNodesAcceptingType(Class<?> factClass) {
7784
// The node needs to match the fact, or the node needs to be applicable to the entire solution.
7885
// The latter is for FromSolution nodes.
79-
return declaredClassToNodeMap.entrySet()
80-
.stream()
86+
return declaredClassToNodeMap.entrySet().stream()
8187
.flatMap(entry -> entry.getValue().stream())
82-
.filter(tupleSourceRoot -> factClass == PlanningSolution.class || tupleSourceRoot.allowsInstancesOf(factClass));
88+
.filter(tupleSourceRoot -> factClass == PlanningSolution.class || tupleSourceRoot.allowsInstancesOf(factClass))
89+
.filter(node -> !isActivationCheckComplete() || activeNodeSet.contains(node));
8390
}
8491

8592
public void settle() {
@@ -102,11 +109,8 @@ public void settle() {
102109
case AbstractTwoInputNode<?, ?> twoInputNode -> twoInputNode.isActive();
103110
})
104111
.peek(activeNodes::add)
105-
.map(propagatorFunction)
106-
.toArray(Propagator[]::new))
107-
.filter(layer -> layer.length > 0)
108-
.peek(AbstractBavetNodeNetwork::settleLayer)
109-
.toArray(Propagator[][]::new);
112+
.map(propagatorFunction).toArray(Propagator[]::new))
113+
.filter(layer -> layer.length > 0).peek(AbstractBavetNodeNetwork::settleLayer).toArray(Propagator[][]::new);
110114
this.activeNodeSet = activeNodes;
111115
return;
112116
}
@@ -116,15 +120,10 @@ public void settle() {
116120
}
117121
}
118122

119-
protected boolean isActivationCheckComplete() {
123+
public boolean isActivationCheckComplete() {
120124
return layeredActivePropagators != null;
121125
}
122126

123-
/**
124-
* For testing only. The nodes that remained active after {@link #settle()}.
125-
*
126-
* @throws IllegalStateException if called before {@link #settle()}.
127-
*/
128127
Set<AbstractNode> getActiveNodes() {
129128
if (activeNodeSet == null) {
130129
throw new IllegalStateException("Impossible state: getActiveNodes() called before settle().");
@@ -136,9 +135,7 @@ Set<AbstractNode> getActiveNodes() {
136135
* For testing only. All nodes in the network, regardless of activity.
137136
*/
138137
List<AbstractNode> getNodes() {
139-
return Arrays.stream(layeredNodes)
140-
.flatMap(Arrays::stream)
141-
.toList();
138+
return Arrays.stream(layeredNodes).flatMap(Arrays::stream).toList();
142139
}
143140

144141
private static void settleLayer(Propagator[] nodesInLayer) {
@@ -174,8 +171,7 @@ public int hashCode() {
174171

175172
@Override
176173
public String toString() {
177-
return "%s with %d forEach nodes."
178-
.formatted(getClass().getSimpleName(), forEachNodeCount());
174+
return "%s with %d forEach nodes.".formatted(getClass().getSimpleName(), forEachNodeCount());
179175
}
180176

181177
}

core/src/main/java/ai/timefold/solver/core/impl/bavet/AbstractSession.java

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package ai.timefold.solver.core.impl.bavet;
22

3-
import java.util.IdentityHashMap;
3+
import java.util.Arrays;
4+
import java.util.HashMap;
45
import java.util.Map;
56

67
import ai.timefold.solver.core.impl.bavet.common.AbstractRootNode;
@@ -16,9 +17,9 @@ public abstract class AbstractSession<Network_ extends AbstractBavetNodeNetwork>
1617

1718
protected AbstractSession(Network_ nodeNetwork) {
1819
this.nodeNetwork = nodeNetwork;
19-
this.insertEffectiveClassToNodeArrayMap = new IdentityHashMap<>(nodeNetwork.forEachNodeCount());
20-
this.updateEffectiveClassToNodeArrayMap = new IdentityHashMap<>(nodeNetwork.forEachNodeCount());
21-
this.retractEffectiveClassToNodeArrayMap = new IdentityHashMap<>(nodeNetwork.forEachNodeCount());
20+
this.insertEffectiveClassToNodeArrayMap = HashMap.newHashMap(nodeNetwork.forEachNodeCount());
21+
this.updateEffectiveClassToNodeArrayMap = HashMap.newHashMap(nodeNetwork.forEachNodeCount());
22+
this.retractEffectiveClassToNodeArrayMap = HashMap.newHashMap(nodeNetwork.forEachNodeCount());
2223
}
2324

2425
public final void insert(Object fact) {
@@ -67,10 +68,23 @@ public final void settle() {
6768
if (settled) {
6869
return;
6970
}
71+
var deactivationComplete = nodeNetwork.isActivationCheckComplete();
7072
nodeNetwork.settle();
73+
if (!deactivationComplete && nodeNetwork.isActivationCheckComplete()) {
74+
removeInactiveRootNodes(insertEffectiveClassToNodeArrayMap);
75+
removeInactiveRootNodes(updateEffectiveClassToNodeArrayMap);
76+
removeInactiveRootNodes(retractEffectiveClassToNodeArrayMap);
77+
}
7178
settled = true;
7279
}
7380

81+
private void removeInactiveRootNodes(Map<Class<?>, AbstractRootNode<Object>[]> effectiveClassToNodeArrayMap) {
82+
// Use getActiveNodes() for this, to not rerun the activity checking logic again.
83+
effectiveClassToNodeArrayMap.replaceAll((k, v) -> Arrays.stream(v)
84+
.filter(n -> nodeNetwork.getActiveNodes().contains(n))
85+
.toArray(AbstractRootNode[]::new));
86+
}
87+
7488
public Network_ getNodeNetwork() {
7589
return nodeNetwork;
7690
}

core/src/main/java/ai/timefold/solver/core/impl/bavet/common/AbstractPrecomputeNode.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,4 +92,5 @@ public final void retract(@Nullable Object a) {
9292
}
9393

9494
protected abstract Tuple_ remapTuple(Tuple_ tuple);
95+
9596
}

core/src/main/java/ai/timefold/solver/core/impl/bavet/common/RecordAndReplayPropagator.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ private void recalculateTuples(AbstractBavetNodeNetwork internalNodeNetwork,
249249
internalTupleToOutputTupleMapper, internalTupleToOutputTupleMap))) {
250250
for (var fact : seenFactSet) {
251251
classToRootNodeList.get(fact.getClass())
252-
.forEach(node -> ((BavetRootNode<Object>) node).update(fact));
252+
.forEach(node -> ((AbstractRootNode<Object>) node).update(fact));
253253
}
254254
internalNodeNetwork.settle();
255255
}

core/src/main/java/ai/timefold/solver/core/impl/bavet/common/tuple/AggregatedTupleLifecycle.java

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package ai.timefold.solver.core.impl.bavet.common.tuple;
22

33
import java.util.Arrays;
4-
import java.util.Objects;
54

65
public final class AggregatedTupleLifecycle<Tuple_ extends Tuple>
76
implements TupleLifecycle<Tuple_> {
@@ -64,12 +63,6 @@ public void retract(Tuple_ tuple) {
6463
}
6564
}
6665

67-
@Override
68-
public boolean equals(Object o) {
69-
return o instanceof AggregatedTupleLifecycle<?> that &&
70-
Arrays.deepEquals(downstream, that.downstream);
71-
}
72-
7366
/**
7467
* Users must not modify this array. (Defensive copy avoided for performance reasons.)
7568
*
@@ -79,9 +72,15 @@ public TupleLifecycle<Tuple_>[] downstream() {
7972
return downstream;
8073
}
8174

75+
@Override
76+
public boolean equals(Object o) {
77+
return o instanceof AggregatedTupleLifecycle<?> that &&
78+
Arrays.deepEquals(downstream, that.downstream);
79+
}
80+
8281
@Override
8382
public int hashCode() {
84-
return Objects.hashCode(downstream);
83+
return Arrays.deepHashCode(downstream);
8584
}
8685

8786
@Override

0 commit comments

Comments
 (0)