Skip to content

Commit 3a1c4cb

Browse files
perf: improve many aspects of DSV performance (#1592)
Less memory allocations. Less map access. Cleaner code structure. --------- Co-authored-by: Christopher Chianelli <christopher@timefold.ai>
1 parent be9ca15 commit 3a1c4cb

22 files changed

Lines changed: 1742 additions & 647 deletions

core/src/main/java/ai/timefold/solver/core/impl/domain/variable/ShadowVariableUpdateHelper.java

Lines changed: 24 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
import java.util.IdentityHashMap;
1414
import java.util.List;
1515
import java.util.Map;
16-
import java.util.Set;
1716

1817
import ai.timefold.solver.core.api.domain.entity.PlanningEntity;
1918
import ai.timefold.solver.core.api.domain.variable.CascadingUpdateShadowVariable;
@@ -34,6 +33,7 @@
3433
import ai.timefold.solver.core.impl.domain.variable.declarative.DefaultShadowVariableSessionFactory;
3534
import ai.timefold.solver.core.impl.domain.variable.declarative.DefaultTopologicalOrderGraph;
3635
import ai.timefold.solver.core.impl.domain.variable.declarative.VariableReferenceGraph;
36+
import ai.timefold.solver.core.impl.domain.variable.declarative.VariableReferenceGraphBuilder;
3737
import ai.timefold.solver.core.impl.domain.variable.descriptor.BasicVariableDescriptor;
3838
import ai.timefold.solver.core.impl.domain.variable.descriptor.ShadowVariableDescriptor;
3939
import ai.timefold.solver.core.impl.domain.variable.index.IndexShadowVariableDescriptor;
@@ -72,18 +72,19 @@ private ShadowVariableUpdateHelper(EnumSet<ShadowVariableType> supportedShadowVa
7272
this.supportedShadowVariableTypes = supportedShadowVariableTypes;
7373
}
7474

75+
@SuppressWarnings("unchecked")
7576
public void updateShadowVariables(Solution_ solution) {
76-
var initialSolutionDescriptor = (SolutionDescriptor<Solution_>) SolutionDescriptor.buildSolutionDescriptor(
77-
Set.of(PreviewFeature.DECLARATIVE_SHADOW_VARIABLES),
78-
solution.getClass());
79-
var entityClassList = initialSolutionDescriptor.getAllEntitiesAndProblemFacts(solution)
77+
var enabledPreviewFeatures = EnumSet.of(PreviewFeature.DECLARATIVE_SHADOW_VARIABLES);
78+
var solutionClass = (Class<Solution_>) solution.getClass();
79+
var initialSolutionDescriptor = SolutionDescriptor.buildSolutionDescriptor(
80+
enabledPreviewFeatures, solutionClass);
81+
var entityClassArray = initialSolutionDescriptor.getAllEntitiesAndProblemFacts(solution)
8082
.stream()
8183
.map(Object::getClass)
8284
.distinct()
83-
.toList();
84-
var solutionDescriptor = (SolutionDescriptor<Solution_>) SolutionDescriptor.buildSolutionDescriptor(
85-
Set.of(PreviewFeature.DECLARATIVE_SHADOW_VARIABLES),
86-
solution.getClass(), entityClassList.toArray(Class[]::new));
85+
.toArray(Class[]::new);
86+
var solutionDescriptor = SolutionDescriptor.buildSolutionDescriptor(enabledPreviewFeatures, solutionClass,
87+
entityClassArray);
8788
try (var scoreDirector = new InternalScoreDirector<>(solutionDescriptor)) {
8889
// When we have a solution, we can reuse the logic from VariableListenerSupport to update all variable types
8990
scoreDirector.setWorkingSolution(solution);
@@ -117,9 +118,8 @@ public void updateShadowVariables(Class<Solution_> solutionClass,
117118
.formatted(missingShadowVariableTypeList));
118119
}
119120
// No solution, we trigger all supported events manually
120-
var session = new InternalShadowVariableSession<>(solutionDescriptor,
121-
new VariableReferenceGraph<>(ChangedVariableNotifier.empty()));
122-
session.init(entities);
121+
var session = InternalShadowVariableSession.build(solutionDescriptor,
122+
new VariableReferenceGraphBuilder<>(ChangedVariableNotifier.empty()), entities);
123123
// Update all built-in shadow variables
124124
var listVariableDescriptor = solutionDescriptor.getListVariableDescriptor();
125125
if (listVariableDescriptor == null) {
@@ -135,11 +135,12 @@ public void updateShadowVariables(Class<Solution_> solutionClass,
135135
private record InternalShadowVariableSession<Solution_>(SolutionDescriptor<Solution_> solutionDescriptor,
136136
VariableReferenceGraph<Solution_> graph) {
137137

138-
public void init(Object... entities) {
139-
if (!solutionDescriptor.getDeclarativeShadowVariableDescriptors().isEmpty()) {
140-
DefaultShadowVariableSessionFactory.visitGraph(solutionDescriptor, graph, entities,
141-
DefaultTopologicalOrderGraph::new);
142-
}
138+
public static <Solution_> InternalShadowVariableSession<Solution_> build(
139+
SolutionDescriptor<Solution_> solutionDescriptor, VariableReferenceGraphBuilder<Solution_> graph,
140+
Object... entities) {
141+
return new InternalShadowVariableSession<>(solutionDescriptor,
142+
DefaultShadowVariableSessionFactory.buildGraph(solutionDescriptor, graph, entities,
143+
DefaultTopologicalOrderGraph::new));
143144
}
144145

145146
/**
@@ -249,6 +250,7 @@ public void processListVariable(Object... entities) {
249250
*
250251
* @param entities the entities to be analyzed
251252
*/
253+
@SuppressWarnings("unchecked")
252254
public void processCascadingVariable(Object... entities) {
253255
var listVariableDescriptor = solutionDescriptor.getListVariableDescriptor();
254256
if (listVariableDescriptor != null) {
@@ -336,8 +338,8 @@ private List<BasicVariableDescriptor<Solution_>> fetchBasicDescriptors(EntityDes
336338
}
337339
}
338340

339-
private static class InternalScoreDirectorFactory<Solution_, Score_ extends Score<Score_>, Factory_ extends AbstractScoreDirectorFactory<Solution_, Score_, Factory_>>
340-
extends AbstractScoreDirectorFactory<Solution_, Score_, Factory_> {
341+
private static class InternalScoreDirectorFactory<Solution_, Score_ extends Score<Score_>>
342+
extends AbstractScoreDirectorFactory<Solution_, Score_, InternalScoreDirectorFactory<Solution_, Score_>> {
341343

342344
public InternalScoreDirectorFactory(SolutionDescriptor<Solution_> solutionDescriptor) {
343345
super(solutionDescriptor);
@@ -349,12 +351,11 @@ public InternalScoreDirectorFactory(SolutionDescriptor<Solution_> solutionDescri
349351
}
350352
}
351353

352-
private static class InternalScoreDirector<Solution_, Score_ extends Score<Score_>, Factory_ extends AbstractScoreDirectorFactory<Solution_, Score_, Factory_>>
353-
extends AbstractScoreDirector<Solution_, Score_, Factory_> {
354+
private static class InternalScoreDirector<Solution_, Score_ extends Score<Score_>>
355+
extends AbstractScoreDirector<Solution_, Score_, InternalScoreDirectorFactory<Solution_, Score_>> {
354356

355357
public InternalScoreDirector(SolutionDescriptor<Solution_> solutionDescriptor) {
356-
super((Factory_) new InternalScoreDirectorFactory<Solution_, Score_, Factory_>(solutionDescriptor), false, DISABLED,
357-
false);
358+
super(new InternalScoreDirectorFactory<>(solutionDescriptor), false, DISABLED, false);
358359
}
359360

360361
@Override
Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
package ai.timefold.solver.core.impl.domain.variable.declarative;
2+
3+
import java.util.BitSet;
4+
import java.util.List;
5+
import java.util.Objects;
6+
import java.util.PriorityQueue;
7+
import java.util.Set;
8+
import java.util.function.Consumer;
9+
import java.util.function.Function;
10+
11+
import ai.timefold.solver.core.impl.domain.variable.descriptor.VariableDescriptor;
12+
import ai.timefold.solver.core.impl.util.LinkedIdentityHashSet;
13+
14+
final class AffectedEntitiesUpdater<Solution_>
15+
implements Consumer<BitSet> {
16+
17+
// From WorkingReferenceGraph.
18+
private final BaseTopologicalOrderGraph graph;
19+
private final List<EntityVariablePair<Solution_>> instanceList; // Immutable.
20+
private final Function<Object, List<EntityVariablePair<Solution_>>> entityVariablePairFunction;
21+
private final ChangedVariableNotifier<Solution_> changedVariableNotifier;
22+
23+
// Internal state; expensive to create, therefore we reuse.
24+
private final AffectedEntities<Solution_> affectedEntities;
25+
private final LoopedTracker loopedTracker;
26+
private final BitSet visited;
27+
private final PriorityQueue<BaseTopologicalOrderGraph.NodeTopologicalOrder> changeQueue;
28+
29+
AffectedEntitiesUpdater(BaseTopologicalOrderGraph graph, List<EntityVariablePair<Solution_>> instanceList,
30+
Function<Object, List<EntityVariablePair<Solution_>>> entityVariablePairFunction,
31+
ChangedVariableNotifier<Solution_> changedVariableNotifier) {
32+
this.graph = graph;
33+
this.instanceList = instanceList;
34+
this.entityVariablePairFunction = entityVariablePairFunction;
35+
this.changedVariableNotifier = changedVariableNotifier;
36+
var instanceCount = instanceList.size();
37+
this.affectedEntities = new AffectedEntities<>(this::updateLoopedStatusOfAffectedEntity);
38+
this.loopedTracker = new LoopedTracker(instanceCount);
39+
this.visited = new BitSet(instanceCount);
40+
this.changeQueue = new PriorityQueue<>(instanceCount);
41+
}
42+
43+
@Override
44+
public void accept(BitSet changed) {
45+
initializeChangeQueue(changed);
46+
47+
while (!changeQueue.isEmpty()) {
48+
var nextNode = changeQueue.poll().nodeId();
49+
if (visited.get(nextNode)) {
50+
continue;
51+
}
52+
visited.set(nextNode);
53+
var shadowVariable = instanceList.get(nextNode);
54+
var isChanged = updateShadowVariable(shadowVariable, graph.isLooped(loopedTracker, nextNode));
55+
56+
if (isChanged) {
57+
var iterator = graph.nodeForwardEdges(nextNode);
58+
while (iterator.hasNext()) {
59+
var nextNodeForwardEdge = iterator.nextInt();
60+
if (!visited.get(nextNodeForwardEdge)) {
61+
changeQueue.add(graph.getTopologicalOrder(nextNodeForwardEdge));
62+
}
63+
}
64+
}
65+
}
66+
67+
affectedEntities.processAndClear();
68+
// Prepare for the next time updateChanged() is called.
69+
// No need to clear changeQueue, as that already finishes empty.
70+
loopedTracker.clear();
71+
visited.clear();
72+
}
73+
74+
private void initializeChangeQueue(BitSet changed) {
75+
// BitSet iteration: get the first set bit at or after 0,
76+
// then get the first set bit after that bit.
77+
// Iteration ends when nextSetBit returns -1.
78+
// This has the potential to overflow, since to do the
79+
// test, we necessarily need to do nextSetBit(i + 1),
80+
// and i + 1 can be negative if Integer.MAX_VALUE is set
81+
// in the BitSet.
82+
// This should never happen, since arrays in Java are limited
83+
// to slightly less than Integer.MAX_VALUE.
84+
for (var i = changed.nextSetBit(0); i >= 0; i = changed.nextSetBit(i + 1)) {
85+
changeQueue.add(graph.getTopologicalOrder(i));
86+
if (i == Integer.MAX_VALUE) {
87+
break; // or (i+1) would overflow
88+
}
89+
}
90+
changed.clear();
91+
}
92+
93+
private void updateLoopedStatusOfAffectedEntity(Object affectedEntity) {
94+
ShadowVariableLoopedVariableDescriptor<Solution_> shadowVariableLoopedDescriptor = null;
95+
var isEntityLooped = false;
96+
for (var node : entityVariablePairFunction.apply(affectedEntity)) {
97+
// All variables come from the same entity,
98+
// therefore all have the same looped marker.
99+
shadowVariableLoopedDescriptor = node.variableReference().shadowVariableLoopedDescriptor();
100+
if (graph.isLooped(loopedTracker, node.graphNodeId())) {
101+
isEntityLooped = true;
102+
break;
103+
}
104+
}
105+
if (shadowVariableLoopedDescriptor == null) {
106+
// At this point, affectedEntity is guaranteed to have looped marker.
107+
// Otherwise AffectedEntities would not have sent it here.
108+
throw new IllegalStateException("Impossible state: loop marker descriptor does not exist.");
109+
}
110+
var oldValue = shadowVariableLoopedDescriptor.getValue(affectedEntity);
111+
if (!Objects.equals(oldValue, isEntityLooped)) {
112+
changeShadowVariableAndNotify(shadowVariableLoopedDescriptor, affectedEntity, isEntityLooped);
113+
}
114+
115+
}
116+
117+
private boolean updateShadowVariable(EntityVariablePair<Solution_> entityVariable, boolean isLooped) {
118+
var entity = entityVariable.entity();
119+
var shadowVariableReference = entityVariable.variableReference();
120+
var oldValue = shadowVariableReference.memberAccessor().executeGetter(entity);
121+
122+
if (isLooped) {
123+
// null might be a valid value, and thus it could be the case
124+
// that is was not looped and null, then turned to looped and null,
125+
// which is still considered a change.
126+
affectedEntities.add(entityVariable);
127+
if (oldValue != null) {
128+
changeShadowVariableAndNotify(shadowVariableReference, entity, null);
129+
}
130+
return true;
131+
} else {
132+
var newValue = shadowVariableReference.calculator().apply(entity);
133+
if (!Objects.equals(oldValue, newValue)) {
134+
affectedEntities.add(entityVariable);
135+
changeShadowVariableAndNotify(shadowVariableReference, entity, newValue);
136+
return true;
137+
}
138+
}
139+
return false;
140+
}
141+
142+
private void changeShadowVariableAndNotify(VariableUpdaterInfo<Solution_> shadowVariableReference, Object entity,
143+
Object newValue) {
144+
var variableDescriptor = shadowVariableReference.variableDescriptor();
145+
changeShadowVariableAndNotify(variableDescriptor, entity, newValue);
146+
}
147+
148+
private void changeShadowVariableAndNotify(VariableDescriptor<Solution_> variableDescriptor, Object entity,
149+
Object newValue) {
150+
changedVariableNotifier.beforeVariableChanged().accept(variableDescriptor, entity);
151+
variableDescriptor.setValue(entity, newValue);
152+
changedVariableNotifier.afterVariableChanged().accept(variableDescriptor, entity);
153+
}
154+
155+
private static final class AffectedEntities<Solution_> {
156+
157+
private final Consumer<Object> consumer;
158+
private final Set<Object> entitiesForLoopedVarUpdateSet;
159+
160+
public AffectedEntities(Consumer<Object> consumer) {
161+
this.consumer = consumer;
162+
this.entitiesForLoopedVarUpdateSet = new LinkedIdentityHashSet<>();
163+
}
164+
165+
public void add(EntityVariablePair<Solution_> shadowVariable) {
166+
var shadowVariableLoopedDescriptor = shadowVariable.variableReference().shadowVariableLoopedDescriptor();
167+
if (shadowVariableLoopedDescriptor == null) {
168+
return;
169+
}
170+
entitiesForLoopedVarUpdateSet.add(shadowVariable.entity());
171+
}
172+
173+
public void processAndClear() {
174+
for (var entity : entitiesForLoopedVarUpdateSet) {
175+
consumer.accept(entity);
176+
}
177+
entitiesForLoopedVarUpdateSet.clear();
178+
}
179+
180+
}
181+
182+
}
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
package ai.timefold.solver.core.impl.domain.variable.declarative;
2+
3+
import java.util.PrimitiveIterator;
4+
5+
/**
6+
* Exists to expose read-only view of {@link TopologicalOrderGraph}.
7+
*/
8+
public interface BaseTopologicalOrderGraph {
9+
10+
/**
11+
* Return an iterator of the nodes that have the `from` node as a predecessor.
12+
*
13+
* @param from The predecessor node.
14+
* @return an iterator of nodes with from as a predecessor.
15+
*/
16+
PrimitiveIterator.OfInt nodeForwardEdges(int from);
17+
18+
/**
19+
* Returns true if a given node is in a strongly connected component with a size
20+
* greater than 1 (i.e. is in a loop) or is a transitive successor of a
21+
* node with the above property.
22+
*
23+
* @param loopedTracker a tracker that can be used to record looped state to avoid
24+
* recomputation.
25+
* @param node The node being queried
26+
* @return true if `node` is in a loop, false otherwise.
27+
*/
28+
boolean isLooped(LoopedTracker loopedTracker, int node);
29+
30+
/**
31+
* Returns a tuple containing node ID and a number corresponding to its topological order.
32+
* In particular, after {@link TopologicalOrderGraph#commitChanges()} is called, the following
33+
* must be true for any pair of nodes A, B where:
34+
* <ul>
35+
* <li>A is a predecessor of B</li>
36+
* <li>`isLooped(A) == isLooped(B) == false`</li>
37+
* </ul>
38+
* getTopologicalOrder(A) &lt; getTopologicalOrder(B)
39+
* <p>
40+
* Said number may not be unique.
41+
*/
42+
NodeTopologicalOrder getTopologicalOrder(int node);
43+
44+
/**
45+
* Stores a graph node id along its topological order.
46+
* Comparisons ignore node id and only use the topological order.
47+
* For instance, for x = (0, 0) and y = (1, 5), x is before y, whereas for
48+
* x = (0, 5) and y = (1, 0), y is before x. Note {@link BaseTopologicalOrderGraph}
49+
* is not guaranteed to return every topological order index (i.e.
50+
* it might be the case no nodes has order 0).
51+
*/
52+
record NodeTopologicalOrder(int nodeId, int order)
53+
implements
54+
Comparable<NodeTopologicalOrder> {
55+
56+
@Override
57+
public int compareTo(NodeTopologicalOrder other) {
58+
return order - other.order;
59+
}
60+
61+
@Override
62+
public boolean equals(Object o) {
63+
if (o instanceof NodeTopologicalOrder other) {
64+
return nodeId == other.nodeId;
65+
}
66+
return false;
67+
}
68+
69+
@Override
70+
public int hashCode() {
71+
return nodeId;
72+
}
73+
74+
}
75+
76+
}

core/src/main/java/ai/timefold/solver/core/impl/domain/variable/declarative/DefaultShadowVariableSession.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
@NullMarked
99
public class DefaultShadowVariableSession<Solution_> implements Supply {
10+
1011
final VariableReferenceGraph<Solution_> graph;
1112

1213
public DefaultShadowVariableSession(VariableReferenceGraph<Solution_> graph) {

0 commit comments

Comments
 (0)