Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,12 @@
import ai.timefold.solver.core.impl.heuristic.selector.move.AbstractMoveSelector;
import ai.timefold.solver.core.impl.heuristic.selector.move.MoveSelector;
import ai.timefold.solver.core.impl.phase.scope.AbstractPhaseScope;
import ai.timefold.solver.core.impl.solver.termination.PhaseTermination;

public final class FilteringMoveSelector<Solution_> extends AbstractMoveSelector<Solution_> {

private static final long BAIL_OUT_MULTIPLIER = 10L;

public static <Solution_> FilteringMoveSelector<Solution_> of(MoveSelector<Solution_> moveSelector,
SelectionFilter<Solution_, Move<Solution_>> filter) {
if (moveSelector instanceof FilteringMoveSelector<Solution_> filteringMoveSelector) {
Expand All @@ -24,6 +27,7 @@ public static <Solution_> FilteringMoveSelector<Solution_> of(MoveSelector<Solut
private final MoveSelector<Solution_> childMoveSelector;
private final SelectionFilter<Solution_, Move<Solution_>> filter;
private final boolean bailOutEnabled;
private AbstractPhaseScope<Solution_> phaseScope;

private ScoreDirector<Solution_> scoreDirector = null;

Expand All @@ -42,13 +46,15 @@ private FilteringMoveSelector(MoveSelector<Solution_> childMoveSelector,
@Override
public void phaseStarted(AbstractPhaseScope<Solution_> phaseScope) {
super.phaseStarted(phaseScope);
scoreDirector = phaseScope.getScoreDirector();
this.scoreDirector = phaseScope.getScoreDirector();
this.phaseScope = phaseScope;
}

@Override
public void phaseEnded(AbstractPhaseScope<Solution_> phaseScope) {
super.phaseEnded(phaseScope);
scoreDirector = null;
this.scoreDirector = null;
this.phaseScope = null;
}

@Override
Expand All @@ -68,23 +74,53 @@ public long getSize() {

@Override
public Iterator<Move<Solution_>> iterator() {
return new JustInTimeFilteringMoveIterator(childMoveSelector.iterator(), determineBailOutSize());
return new JustInTimeFilteringMoveIterator(childMoveSelector.iterator(), determineBailOutSize(), phaseScope);
}

private long determineBailOutSize() {
if (!bailOutEnabled) {
return -1L;
}
try {
return childMoveSelector.getSize() * BAIL_OUT_MULTIPLIER;
} catch (Exception ex) {
// Some move selectors throw an exception when getSize() is called.
// In this case, we choose to disregard it and pick a large-enough bail-out size anyway.
// The ${bailOutSize+1}th move could in theory show up where previous ${bailOutSize} moves did not,
// but we consider this to be an acceptable risk,
// outweighed by the benefit of the solver never running into an endless loop.
// The exception itself is swallowed, as it doesn't bring any useful information.
long bailOutSize = Short.MAX_VALUE * BAIL_OUT_MULTIPLIER;
logger.trace(
" Never-ending move selector ({}) failed to provide size, choosing a bail-out size of ({}) attempts.",
childMoveSelector, bailOutSize);
return bailOutSize;
}
}

private class JustInTimeFilteringMoveIterator extends UpcomingSelectionIterator<Move<Solution_>> {

private final long TERMINATION_BAIL_OUT_SIZE = 1000L;
private final Iterator<Move<Solution_>> childMoveIterator;
private final long bailOutSize;
private final AbstractPhaseScope<Solution_> phaseScope;
private final PhaseTermination<Solution_> termination;

public JustInTimeFilteringMoveIterator(Iterator<Move<Solution_>> childMoveIterator, long bailOutSize) {
public JustInTimeFilteringMoveIterator(Iterator<Move<Solution_>> childMoveIterator, long bailOutSize,
AbstractPhaseScope<Solution_> phaseScope) {
this.childMoveIterator = childMoveIterator;
this.bailOutSize = bailOutSize;
this.phaseScope = phaseScope;
this.termination = phaseScope != null ? phaseScope.getTermination() : null;
}

@Override
protected Move<Solution_> createUpcomingSelection() {
Move<Solution_> next;
long attemptsBeforeBailOut = bailOutSize;
// To reduce the impact of checking for termination on each move,
// we only check for termination after filtering out 1000 moves.
long attemptsBeforeCheckTermination = TERMINATION_BAIL_OUT_SIZE;
do {
if (!childMoveIterator.hasNext()) {
return noUpcomingSelection();
Expand All @@ -95,8 +131,18 @@ protected Move<Solution_> createUpcomingSelection() {
logger.trace("Bailing out of neverEnding selector ({}) after ({}) attempts to avoid infinite loop.",
FilteringMoveSelector.this, bailOutSize);
return noUpcomingSelection();
} else if (termination != null && attemptsBeforeCheckTermination <= 0L) {
// Reset the counter
attemptsBeforeCheckTermination = TERMINATION_BAIL_OUT_SIZE;
if (termination.isPhaseTerminated(phaseScope)) {
logger.trace(
"Bailing out of neverEnding selector ({}) because the termination setting has been triggered.",
FilteringMoveSelector.this);
return noUpcomingSelection();
}
}
attemptsBeforeBailOut--;
attemptsBeforeCheckTermination--;
}
next = childMoveIterator.next();
} while (!accept(scoreDirector, next));
Expand All @@ -105,34 +151,12 @@ protected Move<Solution_> createUpcomingSelection() {

}

private long determineBailOutSize() {
if (!bailOutEnabled) {
return -1L;
}
try {
return childMoveSelector.getSize() * 10L;
} catch (Exception ex) {
// Some move selectors throw an exception when getSize() is called.
// In this case, we choose to disregard it and pick a large-enough bail-out size anyway.
// The ${bailOutSize+1}th move could in theory show up where previous ${bailOutSize} moves did not,
// but we consider this to be an acceptable risk,
// outweighed by the benefit of the solver never running into an endless loop.
// The exception itself is swallowed, as it doesn't bring any useful information.
long bailOutSize = Short.MAX_VALUE * 10L;
logger.trace(
" Never-ending move selector ({}) failed to provide size, choosing a bail-out size of ({}) attempts.",
childMoveSelector, bailOutSize);
return bailOutSize;
}
}

private boolean accept(ScoreDirector<Solution_> scoreDirector, Move<Solution_> move) {
if (filter != null) {
if (!filter.accept(scoreDirector, move)) {
logger.trace(" Move ({}) filtered out by a selection filter ({}).", move, filter);
return false;
}
if (filter != null && !filter.accept(scoreDirector, move)) {
logger.trace(" Move ({}) filtered out by a selection filter ({}).", move, filter);
return false;
}

return true;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ public void phaseStarted(AbstractPhaseScope<Solution_> phaseScope) {
solver.phaseStarted(phaseScope);
}
phaseTermination.phaseStarted(phaseScope);
phaseScope.setTermination(phaseTermination);
phaseLifecycleSupport.firePhaseStarted(phaseScope);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import ai.timefold.solver.core.impl.score.director.InnerScore;
import ai.timefold.solver.core.impl.score.director.InnerScoreDirector;
import ai.timefold.solver.core.impl.solver.scope.SolverScope;
import ai.timefold.solver.core.impl.solver.termination.PhaseTermination;
import ai.timefold.solver.core.preview.api.move.Move;

import org.slf4j.Logger;
Expand Down Expand Up @@ -36,6 +37,11 @@ public abstract class AbstractPhaseScope<Solution_> {

protected int bestSolutionStepIndex;

/**
* The phase termination configuration
*/
private PhaseTermination<Solution_> termination;

/**
* As defined by #AbstractPhaseScope(SolverScope, int, boolean)
* with the phaseSendingBestSolutionEvents parameter set to true.
Expand Down Expand Up @@ -188,6 +194,14 @@ public <Score_ extends Score<Score_>> InnerScoreDirector<Solution_, Score_> getS
return solverScope.getScoreDirector();
}

public void setTermination(PhaseTermination<Solution_> termination) {
this.termination = termination;
}

public PhaseTermination<Solution_> getTermination() {
return termination;
}

public Solution_ getWorkingSolution() {
return solverScope.getWorkingSolution();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,15 @@

import static ai.timefold.solver.core.testutil.PlannerAssert.assertAllCodesOfMoveSelector;
import static ai.timefold.solver.core.testutil.PlannerAssert.verifyPhaseLifecycle;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import java.util.Iterator;

import ai.timefold.solver.core.config.heuristic.selector.common.SelectionCacheType;
import ai.timefold.solver.core.impl.heuristic.move.DummyMove;
import ai.timefold.solver.core.impl.heuristic.selector.SelectorTestUtils;
Expand All @@ -15,6 +19,7 @@
import ai.timefold.solver.core.impl.phase.scope.AbstractPhaseScope;
import ai.timefold.solver.core.impl.phase.scope.AbstractStepScope;
import ai.timefold.solver.core.impl.solver.scope.SolverScope;
import ai.timefold.solver.core.impl.solver.termination.BasicPlumbingTermination;
import ai.timefold.solver.core.testdomain.TestdataSolution;

import org.junit.jupiter.api.Test;
Expand All @@ -41,6 +46,25 @@ void filterCacheTypeJustInTime() {
filter(SelectionCacheType.JUST_IN_TIME, 5);
}

@Test
void bailOutByTermination() {
var phaseScope = mock(AbstractPhaseScope.class);
var moveSelector = mock(MoveSelector.class);
var termination = mock(BasicPlumbingTermination.class);
var iterator = mock(Iterator.class);
when(moveSelector.getSize()).thenReturn(1000L);
when(moveSelector.isNeverEnding()).thenReturn(true);
when(moveSelector.iterator()).thenReturn(iterator);
when(iterator.hasNext()).thenReturn(true);
when(phaseScope.getTermination()).thenReturn(termination);
when(termination.isPhaseTerminated(any(AbstractPhaseScope.class))).thenReturn(false, true);
var filteredMoveSelector = FilteringMoveSelector.of(moveSelector, (scoreDirector, selection) -> false);
filteredMoveSelector.phaseStarted(phaseScope);
assertThat(filteredMoveSelector.iterator().hasNext()).isFalse();
// The termination returns true at the second call, and 2000 calls are executed in total
verify(iterator, times(2000)).next();
}

public void filter(SelectionCacheType cacheType, int timesCalled) {
MoveSelector childMoveSelector = SelectorTestUtils.mockMoveSelector(
new DummyMove("a1"), new DummyMove("a2"), new DummyMove("a3"), new DummyMove("a4"));
Expand Down
Loading