Skip to content

Commit ece6d14

Browse files
chore: split random at the start of step (TimefoldAI#2259)
Done so that multi-threaded scores generally do not get worse than single-threaded.
1 parent 609e361 commit ece6d14

23 files changed

Lines changed: 281 additions & 327 deletions

File tree

core/src/main/java/ai/timefold/solver/core/config/solver/SolverConfig.java

Lines changed: 1 addition & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -41,15 +41,13 @@
4141
import ai.timefold.solver.core.config.score.director.ScoreDirectorFactoryConfig;
4242
import ai.timefold.solver.core.config.solver.monitoring.MonitoringConfig;
4343
import ai.timefold.solver.core.config.solver.monitoring.SolverMetric;
44-
import ai.timefold.solver.core.config.solver.random.RandomType;
4544
import ai.timefold.solver.core.config.solver.termination.TerminationConfig;
4645
import ai.timefold.solver.core.config.util.ConfigUtils;
4746
import ai.timefold.solver.core.impl.domain.common.accessor.MemberAccessor;
4847
import ai.timefold.solver.core.impl.heuristic.selector.common.nearby.NearbyDistanceMeter;
4948
import ai.timefold.solver.core.impl.io.jaxb.SolverConfigIO;
5049
import ai.timefold.solver.core.impl.io.jaxb.TimefoldXmlSerializationException;
5150
import ai.timefold.solver.core.impl.phase.PhaseFactory;
52-
import ai.timefold.solver.core.impl.solver.random.RandomFactory;
5351

5452
import org.jspecify.annotations.NonNull;
5553
import org.jspecify.annotations.Nullable;
@@ -63,9 +61,7 @@
6361
"enablePreviewFeatureSet",
6462
"environmentMode",
6563
"daemon",
66-
"randomType",
6764
"randomSeed",
68-
"randomFactoryClass",
6965
"moveThreadCount",
7066
"moveThreadBufferSize",
7167
"threadFactoryClass",
@@ -215,9 +211,7 @@ public final class SolverConfig extends AbstractConfig<SolverConfig> {
215211
private Set<PreviewFeature> enablePreviewFeatureSet = null;
216212
private EnvironmentMode environmentMode = null;
217213
private Boolean daemon = null;
218-
private RandomType randomType = null;
219214
private Long randomSeed = null;
220-
private Class<? extends RandomFactory> randomFactoryClass = null;
221215
private String moveThreadCount = null;
222216
private Integer moveThreadBufferSize = null;
223217
private Class<? extends ThreadFactory> threadFactoryClass = null;
@@ -332,14 +326,6 @@ public void setDaemon(@Nullable Boolean daemon) {
332326
this.daemon = daemon;
333327
}
334328

335-
public @Nullable RandomType getRandomType() {
336-
return randomType;
337-
}
338-
339-
public void setRandomType(@Nullable RandomType randomType) {
340-
this.randomType = randomType;
341-
}
342-
343329
public @Nullable Long getRandomSeed() {
344330
return randomSeed;
345331
}
@@ -348,14 +334,6 @@ public void setRandomSeed(@Nullable Long randomSeed) {
348334
this.randomSeed = randomSeed;
349335
}
350336

351-
public @Nullable Class<? extends RandomFactory> getRandomFactoryClass() {
352-
return randomFactoryClass;
353-
}
354-
355-
public void setRandomFactoryClass(@Nullable Class<? extends RandomFactory> randomFactoryClass) {
356-
this.randomFactoryClass = randomFactoryClass;
357-
}
358-
359337
public @Nullable String getMoveThreadCount() {
360338
return moveThreadCount;
361339
}
@@ -471,21 +449,11 @@ public void setMonitoringConfig(@Nullable MonitoringConfig monitoringConfig) {
471449
return this;
472450
}
473451

474-
public @NonNull SolverConfig withRandomType(@NonNull RandomType randomType) {
475-
this.randomType = randomType;
476-
return this;
477-
}
478-
479452
public @NonNull SolverConfig withRandomSeed(@NonNull Long randomSeed) {
480453
this.randomSeed = randomSeed;
481454
return this;
482455
}
483456

484-
public @NonNull SolverConfig withRandomFactoryClass(@NonNull Class<? extends RandomFactory> randomFactoryClass) {
485-
this.randomFactoryClass = randomFactoryClass;
486-
return this;
487-
}
488-
489457
public @NonNull SolverConfig withMoveThreadCount(@NonNull String moveThreadCount) {
490458
this.moveThreadCount = moveThreadCount;
491459
return this;
@@ -648,7 +616,7 @@ public boolean canTerminate() {
648616
// ************************************************************************
649617

650618
public void offerRandomSeedFromSubSingleIndex(long subSingleIndex) {
651-
if ((environmentMode == null || environmentMode.isReproducible()) && randomFactoryClass == null && randomSeed == null) {
619+
if ((environmentMode == null || environmentMode.isReproducible()) && randomSeed == null) {
652620
randomSeed = subSingleIndex;
653621
}
654622
}
@@ -665,10 +633,7 @@ public void offerRandomSeedFromSubSingleIndex(long subSingleIndex) {
665633
inheritedConfig.getEnablePreviewFeatureSet());
666634
environmentMode = ConfigUtils.inheritOverwritableProperty(environmentMode, inheritedConfig.getEnvironmentMode());
667635
daemon = ConfigUtils.inheritOverwritableProperty(daemon, inheritedConfig.getDaemon());
668-
randomType = ConfigUtils.inheritOverwritableProperty(randomType, inheritedConfig.getRandomType());
669636
randomSeed = ConfigUtils.inheritOverwritableProperty(randomSeed, inheritedConfig.getRandomSeed());
670-
randomFactoryClass = ConfigUtils.inheritOverwritableProperty(randomFactoryClass,
671-
inheritedConfig.getRandomFactoryClass());
672637
moveThreadCount = ConfigUtils.inheritOverwritableProperty(moveThreadCount,
673638
inheritedConfig.getMoveThreadCount());
674639
moveThreadBufferSize = ConfigUtils.inheritOverwritableProperty(moveThreadBufferSize,
@@ -700,7 +665,6 @@ public void offerRandomSeedFromSubSingleIndex(long subSingleIndex) {
700665

701666
@Override
702667
public void visitReferencedClasses(@NonNull Consumer<Class<?>> classVisitor) {
703-
classVisitor.accept(randomFactoryClass);
704668
classVisitor.accept(threadFactoryClass);
705669
classVisitor.accept(solutionClass);
706670
if (entityClassList != null) {

core/src/main/java/ai/timefold/solver/core/config/solver/random/RandomType.java

Lines changed: 0 additions & 23 deletions
This file was deleted.

core/src/main/java/ai/timefold/solver/core/config/solver/random/package-info.java

Lines changed: 0 additions & 9 deletions
This file was deleted.

core/src/main/java/ai/timefold/solver/core/impl/constructionheuristic/DefaultConstructionHeuristicPhase.java

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -104,9 +104,9 @@ public void solve(SolverScope<Solution_> solverScope) {
104104
stepScope.getPhaseScope().calculateSolverTimeMillisSpentUpToNow());
105105
}
106106
} else {
107-
throw new IllegalStateException("The step index (" + stepScope.getStepIndex()
108-
+ ") has selected move count (" + stepScope.getSelectedMoveCount()
109-
+ ") but failed to pick a nextStep (" + stepScope.getStep() + ").");
107+
throw new IllegalStateException(
108+
"The step index (%d) has selected move count (%d) but failed to pick a nextStep (%s).".formatted(
109+
stepScope.getStepIndex(), stepScope.getSelectedMoveCount(), stepScope.getStep()));
110110
}
111111
// Although stepStarted has been called, stepEnded is not called for this step.
112112
earlyTerminationStatus = TerminationStatus.early(phaseScope.getNextStepIndex());
@@ -190,11 +190,16 @@ public void phaseEnded(ConstructionHeuristicPhaseScope<Solution_> phaseScope) {
190190
phaseScope.endingNow();
191191
if (decider.isLoggingEnabled() && logger.isInfoEnabled()) {
192192
logger.info(
193-
"{}Construction Heuristic phase ({}) ended: time spent ({}), best score ({}), move evaluation speed ({}/sec), step total ({}).",
193+
"""
194+
{}Construction Heuristic phase ({}) ended: time spent ({}), best score ({}), \
195+
{}move evaluation speed ({}/sec), step total ({}).""",
194196
logIndentation,
195197
phaseIndex,
196198
phaseScope.calculateSolverTimeMillisSpentUpToNow(),
197199
phaseScope.getBestScore().raw(),
200+
// Multithreaded solving uses "effective" move evaluation speed, since not all evaluated moves
201+
// are foraged
202+
(decider.getClass().equals(ConstructionHeuristicDecider.class)) ? "" : "effective ",
198203
phaseScope.getPhaseMoveEvaluationSpeed(),
199204
phaseScope.getNextStepIndex());
200205
}

core/src/main/java/ai/timefold/solver/core/impl/exhaustivesearch/DefaultExhaustiveSearchPhase.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,9 @@ private void phaseEnded(ExhaustiveSearchPhaseScope<Solution_> phaseScope) {
100100
super.phaseEnded(phaseScope);
101101
decider.phaseEnded(phaseScope);
102102
phaseScope.endingNow();
103-
logger.info("{}Exhaustive Search phase ({}) ended: time spent ({}), best score ({}),"
104-
+ " move evaluation speed ({}/sec), step total ({}).",
103+
logger.info("""
104+
{}Exhaustive Search phase ({}) ended: time spent ({}), best score ({}),\
105+
move evaluation speed ({}/sec), step total ({}).""",
105106
logIndentation,
106107
phaseIndex,
107108
phaseScope.calculateSolverTimeMillisSpentUpToNow(),

core/src/main/java/ai/timefold/solver/core/impl/localsearch/DefaultLocalSearchPhase.java

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -94,16 +94,17 @@ public void solve(SolverScope<Solution_> solverScope) {
9494
stepScope.getStepIndex(),
9595
stepScope.getPhaseScope().calculateSolverTimeMillisSpentUpToNow());
9696
} else if (stepScope.getSelectedMoveCount() == 0L) {
97-
logger.warn("{} No doable selected move at step index ({}), time spent ({})."
98-
+ " Terminating phase early.",
97+
logger.warn("""
98+
{} No doable selected move at step index ({}), time spent ({}). \
99+
Terminating phase early.""",
99100
logIndentation,
100101
stepScope.getStepIndex(),
101102
stepScope.getPhaseScope().calculateSolverTimeMillisSpentUpToNow());
102103
} else {
103-
throw new IllegalStateException("The step index (" + stepScope.getStepIndex()
104-
+ ") has accepted/selected move count (" + stepScope.getAcceptedMoveCount() + "/"
105-
+ stepScope.getSelectedMoveCount()
106-
+ ") but failed to pick a nextStep (" + stepScope.getStep() + ").");
104+
throw new IllegalStateException(
105+
"The step index (%d) has accepted/selected move count (%d/%d) but failed to pick a nextStep (%s)."
106+
.formatted(stepScope.getStepIndex(), stepScope.getAcceptedMoveCount(),
107+
stepScope.getSelectedMoveCount(), stepScope.getStep()));
107108
}
108109
// Although stepStarted has been called, stepEnded is not called for this step
109110
break;
@@ -151,17 +152,19 @@ public void stepEnded(LocalSearchStepScope<Solution_> stepScope) {
151152
if (logger.isDebugEnabled()) {
152153
if (stepScope.getAcceptedMoveCount() == 0 && phaseTermination.isPhaseTerminated(phaseScope)) {
153154
// Terminated early
154-
logger.debug("{} LS step ({}), time spent ({}), score ({}), {} best score ({})," +
155-
" terminated prematurely after selecting {} moves.",
155+
logger.debug("""
156+
{} LS step ({}), time spent ({}), score ({}), {} best score ({}), \
157+
terminated prematurely after selecting {} moves.""",
156158
logIndentation,
157159
stepScope.getStepIndex(),
158160
phaseScope.calculateSolverTimeMillisSpentUpToNow(),
159161
stepScope.getScore().raw(),
160162
(stepScope.getBestScoreImproved() ? "new" : " "), phaseScope.getBestScore().raw(),
161163
stepScope.getSelectedMoveCount());
162164
} else {
163-
logger.debug("{} LS step ({}), time spent ({}), score ({}), {} best score ({})," +
164-
" accepted/selected move count ({}/{}), picked move ({}).",
165+
logger.debug("""
166+
{} LS step ({}), time spent ({}), score ({}), {} best score ({}), \
167+
accepted/selected move count ({}/{}), picked move ({}).""",
165168
logIndentation,
166169
stepScope.getStepIndex(),
167170
phaseScope.calculateSolverTimeMillisSpentUpToNow(),
@@ -224,12 +227,16 @@ public void phaseEnded(LocalSearchPhaseScope<Solution_> phaseScope) {
224227
super.phaseEnded(phaseScope);
225228
decider.phaseEnded(phaseScope);
226229
phaseScope.endingNow();
227-
logger.info("{}Local Search phase ({}) ended: time spent ({}), best score ({}),"
228-
+ " move evaluation speed ({}/sec), step total ({}).",
230+
logger.info("""
231+
{}Local Search phase ({}) ended: time spent ({}), best score ({}), \
232+
{}move evaluation speed ({}/sec), step total ({}).""",
229233
logIndentation,
230234
phaseIndex,
231235
phaseScope.calculateSolverTimeMillisSpentUpToNow(),
232236
phaseScope.getBestScore().raw(),
237+
// Multithreaded solving uses "effective" move evaluation speed, since not all evaluated moves
238+
// are foraged
239+
(decider.getClass().equals(LocalSearchDecider.class)) ? "" : "effective ",
233240
phaseScope.getPhaseMoveEvaluationSpeed(),
234241
phaseScope.getNextStepIndex());
235242
}

core/src/main/java/ai/timefold/solver/core/impl/solver/AbstractSolver.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
import java.util.Iterator;
44
import java.util.List;
5+
import java.util.Objects;
6+
import java.util.random.RandomGenerator;
57

68
import ai.timefold.solver.core.api.domain.solution.PlanningSolution;
79
import ai.timefold.solver.core.api.solver.Solver;
@@ -12,11 +14,13 @@
1214
import ai.timefold.solver.core.impl.phase.scope.AbstractPhaseScope;
1315
import ai.timefold.solver.core.impl.phase.scope.AbstractStepScope;
1416
import ai.timefold.solver.core.impl.solver.event.SolverEventSupport;
17+
import ai.timefold.solver.core.impl.solver.random.DelegatingSplittableRandomGenerator;
1518
import ai.timefold.solver.core.impl.solver.recaller.BestSolutionRecaller;
1619
import ai.timefold.solver.core.impl.solver.scope.SolverScope;
1720
import ai.timefold.solver.core.impl.solver.termination.UniversalTermination;
1821

1922
import org.jspecify.annotations.NullMarked;
23+
import org.jspecify.annotations.Nullable;
2024
import org.slf4j.Logger;
2125
import org.slf4j.LoggerFactory;
2226

@@ -44,6 +48,8 @@ public abstract class AbstractSolver<Solution_> implements Solver<Solution_> {
4448
protected final UniversalTermination<Solution_> globalTermination;
4549
protected final List<Phase<Solution_>> phaseList;
4650

51+
private RandomGenerator.@Nullable SplittableGenerator savedRandom;
52+
4753
// ************************************************************************
4854
// Constructors and simple getters/setters
4955
// ************************************************************************
@@ -123,10 +129,18 @@ public void stepStarted(AbstractStepScope<Solution_> stepScope) {
123129
bestSolutionRecaller.stepStarted(stepScope);
124130
phaseLifecycleSupport.fireStepStarted(stepScope);
125131
globalTermination.stepStarted(stepScope);
132+
// To ensure reproducibility even when the number of random calls is not deterministic,
133+
// split the random at step start.
134+
var delegatingRandom = ((DelegatingSplittableRandomGenerator) stepScope.getWorkingRandom());
135+
savedRandom = delegatingRandom.getDelegate();
136+
delegatingRandom.setDelegate(delegatingRandom.split());
126137
// Do not propagate to phases; the active phase does that for itself and they should not propagate further.
127138
}
128139

129140
public void stepEnded(AbstractStepScope<Solution_> stepScope) {
141+
// Restore from the split random
142+
var delegatingRandom = ((DelegatingSplittableRandomGenerator) stepScope.getWorkingRandom());
143+
delegatingRandom.setDelegate(Objects.requireNonNull(savedRandom));
130144
bestSolutionRecaller.stepEnded(stepScope);
131145
phaseLifecycleSupport.fireStepEnded(stepScope);
132146
globalTermination.stepEnded(stepScope);

core/src/main/java/ai/timefold/solver/core/impl/solver/DefaultSolver.java

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
import java.util.Map;
66
import java.util.Objects;
77
import java.util.concurrent.atomic.AtomicBoolean;
8+
import java.util.function.Supplier;
9+
import java.util.random.RandomGenerator;
810

911
import ai.timefold.solver.core.api.domain.common.PlanningId;
1012
import ai.timefold.solver.core.api.domain.solution.PlanningSolution;
@@ -16,7 +18,6 @@
1618
import ai.timefold.solver.core.impl.phase.Phase;
1719
import ai.timefold.solver.core.impl.score.director.InnerScoreDirector;
1820
import ai.timefold.solver.core.impl.score.director.ScoreDirectorFactory;
19-
import ai.timefold.solver.core.impl.solver.random.RandomFactory;
2021
import ai.timefold.solver.core.impl.solver.recaller.BestSolutionRecaller;
2122
import ai.timefold.solver.core.impl.solver.scope.SolverScope;
2223
import ai.timefold.solver.core.impl.solver.termination.BasicPlumbingTermination;
@@ -38,7 +39,7 @@
3839
public class DefaultSolver<Solution_> extends AbstractSolver<Solution_> {
3940

4041
protected final EnvironmentMode environmentMode;
41-
protected final RandomFactory randomFactory;
42+
protected final Supplier<RandomGenerator> randomFactory;
4243
protected final BasicPlumbingTermination<Solution_> basicPlumbingTermination;
4344
protected final AtomicBoolean solving = new AtomicBoolean(false);
4445
protected final SolverScope<Solution_> solverScope;
@@ -48,7 +49,7 @@ public class DefaultSolver<Solution_> extends AbstractSolver<Solution_> {
4849
// Constructors and simple getters/setters
4950
// ************************************************************************
5051

51-
public DefaultSolver(EnvironmentMode environmentMode, RandomFactory randomFactory,
52+
public DefaultSolver(EnvironmentMode environmentMode, Supplier<RandomGenerator> randomFactory,
5253
BestSolutionRecaller<Solution_> bestSolutionRecaller, BasicPlumbingTermination<Solution_> basicPlumbingTermination,
5354
UniversalTermination<Solution_> termination, List<Phase<Solution_>> phaseList,
5455
SolverScope<Solution_> solverScope, String moveThreadCountDescription) {
@@ -65,8 +66,8 @@ public EnvironmentMode getEnvironmentMode() {
6566
return environmentMode;
6667
}
6768

68-
public RandomFactory getRandomFactory() {
69-
return randomFactory;
69+
public RandomGenerator getRandomGenerator() {
70+
return randomFactory.get();
7071
}
7172

7273
public ScoreDirectorFactory<Solution_, ?> getScoreDirectorFactory() {
@@ -187,7 +188,7 @@ public void outerSolvingStarted(SolverScope<Solution_> solverScope) {
187188
solving.set(true);
188189
basicPlumbingTermination.resetTerminateEarly();
189190
solverScope.setStartingSolverCount(0);
190-
solverScope.setWorkingRandom(randomFactory.createRandom());
191+
solverScope.setWorkingRandom(randomFactory.get());
191192
}
192193

193194
@Override

0 commit comments

Comments
 (0)