diff --git a/core/src/main/java/ai/timefold/solver/core/impl/domain/variable/ShadowVariableUpdateHelper.java b/core/src/main/java/ai/timefold/solver/core/impl/domain/variable/ShadowVariableUpdateHelper.java index fe4b7787881..f2c8d6e72e6 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/domain/variable/ShadowVariableUpdateHelper.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/domain/variable/ShadowVariableUpdateHelper.java @@ -13,7 +13,6 @@ import java.util.IdentityHashMap; import java.util.List; import java.util.Map; -import java.util.Set; import ai.timefold.solver.core.api.domain.entity.PlanningEntity; import ai.timefold.solver.core.api.domain.variable.CascadingUpdateShadowVariable; @@ -34,6 +33,7 @@ import ai.timefold.solver.core.impl.domain.variable.declarative.DefaultShadowVariableSessionFactory; import ai.timefold.solver.core.impl.domain.variable.declarative.DefaultTopologicalOrderGraph; import ai.timefold.solver.core.impl.domain.variable.declarative.VariableReferenceGraph; +import ai.timefold.solver.core.impl.domain.variable.declarative.VariableReferenceGraphBuilder; import ai.timefold.solver.core.impl.domain.variable.descriptor.BasicVariableDescriptor; import ai.timefold.solver.core.impl.domain.variable.descriptor.ShadowVariableDescriptor; import ai.timefold.solver.core.impl.domain.variable.index.IndexShadowVariableDescriptor; @@ -72,18 +72,19 @@ private ShadowVariableUpdateHelper(EnumSet supportedShadowVa this.supportedShadowVariableTypes = supportedShadowVariableTypes; } + @SuppressWarnings("unchecked") public void updateShadowVariables(Solution_ solution) { - var initialSolutionDescriptor = (SolutionDescriptor) SolutionDescriptor.buildSolutionDescriptor( - Set.of(PreviewFeature.DECLARATIVE_SHADOW_VARIABLES), - solution.getClass()); - var entityClassList = initialSolutionDescriptor.getAllEntitiesAndProblemFacts(solution) + var enabledPreviewFeatures = EnumSet.of(PreviewFeature.DECLARATIVE_SHADOW_VARIABLES); + var solutionClass = (Class) solution.getClass(); + var initialSolutionDescriptor = SolutionDescriptor.buildSolutionDescriptor( + enabledPreviewFeatures, solutionClass); + var entityClassArray = initialSolutionDescriptor.getAllEntitiesAndProblemFacts(solution) .stream() .map(Object::getClass) .distinct() - .toList(); - var solutionDescriptor = (SolutionDescriptor) SolutionDescriptor.buildSolutionDescriptor( - Set.of(PreviewFeature.DECLARATIVE_SHADOW_VARIABLES), - solution.getClass(), entityClassList.toArray(Class[]::new)); + .toArray(Class[]::new); + var solutionDescriptor = SolutionDescriptor.buildSolutionDescriptor(enabledPreviewFeatures, solutionClass, + entityClassArray); try (var scoreDirector = new InternalScoreDirector<>(solutionDescriptor)) { // When we have a solution, we can reuse the logic from VariableListenerSupport to update all variable types scoreDirector.setWorkingSolution(solution); @@ -117,9 +118,8 @@ public void updateShadowVariables(Class solutionClass, .formatted(missingShadowVariableTypeList)); } // No solution, we trigger all supported events manually - var session = new InternalShadowVariableSession<>(solutionDescriptor, - new VariableReferenceGraph<>(ChangedVariableNotifier.empty())); - session.init(entities); + var session = InternalShadowVariableSession.build(solutionDescriptor, + new VariableReferenceGraphBuilder<>(ChangedVariableNotifier.empty()), entities); // Update all built-in shadow variables var listVariableDescriptor = solutionDescriptor.getListVariableDescriptor(); if (listVariableDescriptor == null) { @@ -135,11 +135,12 @@ public void updateShadowVariables(Class solutionClass, private record InternalShadowVariableSession(SolutionDescriptor solutionDescriptor, VariableReferenceGraph graph) { - public void init(Object... entities) { - if (!solutionDescriptor.getDeclarativeShadowVariableDescriptors().isEmpty()) { - DefaultShadowVariableSessionFactory.visitGraph(solutionDescriptor, graph, entities, - DefaultTopologicalOrderGraph::new); - } + public static InternalShadowVariableSession build( + SolutionDescriptor solutionDescriptor, VariableReferenceGraphBuilder graph, + Object... entities) { + return new InternalShadowVariableSession<>(solutionDescriptor, + DefaultShadowVariableSessionFactory.buildGraph(solutionDescriptor, graph, entities, + DefaultTopologicalOrderGraph::new)); } /** @@ -249,6 +250,7 @@ public void processListVariable(Object... entities) { * * @param entities the entities to be analyzed */ + @SuppressWarnings("unchecked") public void processCascadingVariable(Object... entities) { var listVariableDescriptor = solutionDescriptor.getListVariableDescriptor(); if (listVariableDescriptor != null) { @@ -336,8 +338,8 @@ private List> fetchBasicDescriptors(EntityDes } } - private static class InternalScoreDirectorFactory, Factory_ extends AbstractScoreDirectorFactory> - extends AbstractScoreDirectorFactory { + private static class InternalScoreDirectorFactory> + extends AbstractScoreDirectorFactory> { public InternalScoreDirectorFactory(SolutionDescriptor solutionDescriptor) { super(solutionDescriptor); @@ -349,12 +351,11 @@ public InternalScoreDirectorFactory(SolutionDescriptor solutionDescri } } - private static class InternalScoreDirector, Factory_ extends AbstractScoreDirectorFactory> - extends AbstractScoreDirector { + private static class InternalScoreDirector> + extends AbstractScoreDirector> { public InternalScoreDirector(SolutionDescriptor solutionDescriptor) { - super((Factory_) new InternalScoreDirectorFactory(solutionDescriptor), false, DISABLED, - false); + super(new InternalScoreDirectorFactory<>(solutionDescriptor), false, DISABLED, false); } @Override diff --git a/core/src/main/java/ai/timefold/solver/core/impl/domain/variable/declarative/AffectedEntitiesUpdater.java b/core/src/main/java/ai/timefold/solver/core/impl/domain/variable/declarative/AffectedEntitiesUpdater.java new file mode 100644 index 00000000000..4d8b4c94ceb --- /dev/null +++ b/core/src/main/java/ai/timefold/solver/core/impl/domain/variable/declarative/AffectedEntitiesUpdater.java @@ -0,0 +1,182 @@ +package ai.timefold.solver.core.impl.domain.variable.declarative; + +import java.util.BitSet; +import java.util.List; +import java.util.Objects; +import java.util.PriorityQueue; +import java.util.Set; +import java.util.function.Consumer; +import java.util.function.Function; + +import ai.timefold.solver.core.impl.domain.variable.descriptor.VariableDescriptor; +import ai.timefold.solver.core.impl.util.LinkedIdentityHashSet; + +final class AffectedEntitiesUpdater + implements Consumer { + + // From WorkingReferenceGraph. + private final BaseTopologicalOrderGraph graph; + private final List> instanceList; // Immutable. + private final Function>> entityVariablePairFunction; + private final ChangedVariableNotifier changedVariableNotifier; + + // Internal state; expensive to create, therefore we reuse. + private final AffectedEntities affectedEntities; + private final LoopedTracker loopedTracker; + private final BitSet visited; + private final PriorityQueue changeQueue; + + AffectedEntitiesUpdater(BaseTopologicalOrderGraph graph, List> instanceList, + Function>> entityVariablePairFunction, + ChangedVariableNotifier changedVariableNotifier) { + this.graph = graph; + this.instanceList = instanceList; + this.entityVariablePairFunction = entityVariablePairFunction; + this.changedVariableNotifier = changedVariableNotifier; + var instanceCount = instanceList.size(); + this.affectedEntities = new AffectedEntities<>(this::updateLoopedStatusOfAffectedEntity); + this.loopedTracker = new LoopedTracker(instanceCount); + this.visited = new BitSet(instanceCount); + this.changeQueue = new PriorityQueue<>(instanceCount); + } + + @Override + public void accept(BitSet changed) { + initializeChangeQueue(changed); + + while (!changeQueue.isEmpty()) { + var nextNode = changeQueue.poll().nodeId(); + if (visited.get(nextNode)) { + continue; + } + visited.set(nextNode); + var shadowVariable = instanceList.get(nextNode); + var isChanged = updateShadowVariable(shadowVariable, graph.isLooped(loopedTracker, nextNode)); + + if (isChanged) { + var iterator = graph.nodeForwardEdges(nextNode); + while (iterator.hasNext()) { + var nextNodeForwardEdge = iterator.nextInt(); + if (!visited.get(nextNodeForwardEdge)) { + changeQueue.add(graph.getTopologicalOrder(nextNodeForwardEdge)); + } + } + } + } + + affectedEntities.processAndClear(); + // Prepare for the next time updateChanged() is called. + // No need to clear changeQueue, as that already finishes empty. + loopedTracker.clear(); + visited.clear(); + } + + private void initializeChangeQueue(BitSet changed) { + // BitSet iteration: get the first set bit at or after 0, + // then get the first set bit after that bit. + // Iteration ends when nextSetBit returns -1. + // This has the potential to overflow, since to do the + // test, we necessarily need to do nextSetBit(i + 1), + // and i + 1 can be negative if Integer.MAX_VALUE is set + // in the BitSet. + // This should never happen, since arrays in Java are limited + // to slightly less than Integer.MAX_VALUE. + for (var i = changed.nextSetBit(0); i >= 0; i = changed.nextSetBit(i + 1)) { + changeQueue.add(graph.getTopologicalOrder(i)); + if (i == Integer.MAX_VALUE) { + break; // or (i+1) would overflow + } + } + changed.clear(); + } + + private void updateLoopedStatusOfAffectedEntity(Object affectedEntity) { + ShadowVariableLoopedVariableDescriptor shadowVariableLoopedDescriptor = null; + var isEntityLooped = false; + for (var node : entityVariablePairFunction.apply(affectedEntity)) { + // All variables come from the same entity, + // therefore all have the same looped marker. + shadowVariableLoopedDescriptor = node.variableReference().shadowVariableLoopedDescriptor(); + if (graph.isLooped(loopedTracker, node.graphNodeId())) { + isEntityLooped = true; + break; + } + } + if (shadowVariableLoopedDescriptor == null) { + // At this point, affectedEntity is guaranteed to have looped marker. + // Otherwise AffectedEntities would not have sent it here. + throw new IllegalStateException("Impossible state: loop marker descriptor does not exist."); + } + var oldValue = shadowVariableLoopedDescriptor.getValue(affectedEntity); + if (!Objects.equals(oldValue, isEntityLooped)) { + changeShadowVariableAndNotify(shadowVariableLoopedDescriptor, affectedEntity, isEntityLooped); + } + + } + + private boolean updateShadowVariable(EntityVariablePair entityVariable, boolean isLooped) { + var entity = entityVariable.entity(); + var shadowVariableReference = entityVariable.variableReference(); + var oldValue = shadowVariableReference.memberAccessor().executeGetter(entity); + + if (isLooped) { + // null might be a valid value, and thus it could be the case + // that is was not looped and null, then turned to looped and null, + // which is still considered a change. + affectedEntities.add(entityVariable); + if (oldValue != null) { + changeShadowVariableAndNotify(shadowVariableReference, entity, null); + } + return true; + } else { + var newValue = shadowVariableReference.calculator().apply(entity); + if (!Objects.equals(oldValue, newValue)) { + affectedEntities.add(entityVariable); + changeShadowVariableAndNotify(shadowVariableReference, entity, newValue); + return true; + } + } + return false; + } + + private void changeShadowVariableAndNotify(VariableUpdaterInfo shadowVariableReference, Object entity, + Object newValue) { + var variableDescriptor = shadowVariableReference.variableDescriptor(); + changeShadowVariableAndNotify(variableDescriptor, entity, newValue); + } + + private void changeShadowVariableAndNotify(VariableDescriptor variableDescriptor, Object entity, + Object newValue) { + changedVariableNotifier.beforeVariableChanged().accept(variableDescriptor, entity); + variableDescriptor.setValue(entity, newValue); + changedVariableNotifier.afterVariableChanged().accept(variableDescriptor, entity); + } + + private static final class AffectedEntities { + + private final Consumer consumer; + private final Set entitiesForLoopedVarUpdateSet; + + public AffectedEntities(Consumer consumer) { + this.consumer = consumer; + this.entitiesForLoopedVarUpdateSet = new LinkedIdentityHashSet<>(); + } + + public void add(EntityVariablePair shadowVariable) { + var shadowVariableLoopedDescriptor = shadowVariable.variableReference().shadowVariableLoopedDescriptor(); + if (shadowVariableLoopedDescriptor == null) { + return; + } + entitiesForLoopedVarUpdateSet.add(shadowVariable.entity()); + } + + public void processAndClear() { + for (var entity : entitiesForLoopedVarUpdateSet) { + consumer.accept(entity); + } + entitiesForLoopedVarUpdateSet.clear(); + } + + } + +} diff --git a/core/src/main/java/ai/timefold/solver/core/impl/domain/variable/declarative/BaseTopologicalOrderGraph.java b/core/src/main/java/ai/timefold/solver/core/impl/domain/variable/declarative/BaseTopologicalOrderGraph.java new file mode 100644 index 00000000000..c68f6fadc34 --- /dev/null +++ b/core/src/main/java/ai/timefold/solver/core/impl/domain/variable/declarative/BaseTopologicalOrderGraph.java @@ -0,0 +1,76 @@ +package ai.timefold.solver.core.impl.domain.variable.declarative; + +import java.util.PrimitiveIterator; + +/** + * Exists to expose read-only view of {@link TopologicalOrderGraph}. + */ +public interface BaseTopologicalOrderGraph { + + /** + * Return an iterator of the nodes that have the `from` node as a predecessor. + * + * @param from The predecessor node. + * @return an iterator of nodes with from as a predecessor. + */ + PrimitiveIterator.OfInt nodeForwardEdges(int from); + + /** + * Returns true if a given node is in a strongly connected component with a size + * greater than 1 (i.e. is in a loop) or is a transitive successor of a + * node with the above property. + * + * @param loopedTracker a tracker that can be used to record looped state to avoid + * recomputation. + * @param node The node being queried + * @return true if `node` is in a loop, false otherwise. + */ + boolean isLooped(LoopedTracker loopedTracker, int node); + + /** + * Returns a tuple containing node ID and a number corresponding to its topological order. + * In particular, after {@link TopologicalOrderGraph#commitChanges()} is called, the following + * must be true for any pair of nodes A, B where: + *
    + *
  • A is a predecessor of B
  • + *
  • `isLooped(A) == isLooped(B) == false`
  • + *
+ * getTopologicalOrder(A) < getTopologicalOrder(B) + *

+ * Said number may not be unique. + */ + NodeTopologicalOrder getTopologicalOrder(int node); + + /** + * Stores a graph node id along its topological order. + * Comparisons ignore node id and only use the topological order. + * For instance, for x = (0, 0) and y = (1, 5), x is before y, whereas for + * x = (0, 5) and y = (1, 0), y is before x. Note {@link BaseTopologicalOrderGraph} + * is not guaranteed to return every topological order index (i.e. + * it might be the case no nodes has order 0). + */ + record NodeTopologicalOrder(int nodeId, int order) + implements + Comparable { + + @Override + public int compareTo(NodeTopologicalOrder other) { + return order - other.order; + } + + @Override + public boolean equals(Object o) { + if (o instanceof NodeTopologicalOrder other) { + return nodeId == other.nodeId; + } + return false; + } + + @Override + public int hashCode() { + return nodeId; + } + + } + +} \ No newline at end of file diff --git a/core/src/main/java/ai/timefold/solver/core/impl/domain/variable/declarative/DefaultShadowVariableSession.java b/core/src/main/java/ai/timefold/solver/core/impl/domain/variable/declarative/DefaultShadowVariableSession.java index 5e91d4cd0c0..56dac46cfdb 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/domain/variable/declarative/DefaultShadowVariableSession.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/domain/variable/declarative/DefaultShadowVariableSession.java @@ -7,6 +7,7 @@ @NullMarked public class DefaultShadowVariableSession implements Supply { + final VariableReferenceGraph graph; public DefaultShadowVariableSession(VariableReferenceGraph graph) { diff --git a/core/src/main/java/ai/timefold/solver/core/impl/domain/variable/declarative/DefaultShadowVariableSessionFactory.java b/core/src/main/java/ai/timefold/solver/core/impl/domain/variable/declarative/DefaultShadowVariableSessionFactory.java index 687c7b0ce29..66384249f61 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/domain/variable/declarative/DefaultShadowVariableSessionFactory.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/domain/variable/declarative/DefaultShadowVariableSessionFactory.java @@ -1,7 +1,6 @@ package ai.timefold.solver.core.impl.domain.variable.declarative; import java.util.ArrayList; -import java.util.Collections; import java.util.HashMap; import java.util.LinkedHashSet; import java.util.List; @@ -30,59 +29,62 @@ public DefaultShadowVariableSessionFactory( this.graphCreator = graphCreator; } - public static void visitGraph( + @SuppressWarnings("unchecked") + public static VariableReferenceGraph buildGraph( SolutionDescriptor solutionDescriptor, - VariableReferenceGraph variableReferenceGraph, Object[] entities, + VariableReferenceGraphBuilder variableReferenceGraphBuilder, Object[] entities, IntFunction graphCreator) { var declarativeShadowVariableDescriptors = solutionDescriptor.getDeclarativeShadowVariableDescriptors(); - var variableIdToUpdater = new HashMap, VariableUpdaterInfo>(); + if (declarativeShadowVariableDescriptors.isEmpty()) { + return EmptyVariableReferenceGraph.INSTANCE; + } + var variableIdToUpdater = new HashMap, VariableUpdaterInfo>(); + // Create graph node for each entity/declarative shadow variable pair. // Maps a variable id to it source aliases; // For instance, "previousVisit.startTime" is a source alias of "startTime" // One way to view this concept is "previousVisit.startTime" is a pointer // to "startTime" of some visit, and thus alias it. - Map, Set> declarativeShadowVariableToAliasMap = new HashMap<>(); - - // Create graph node for each entity/declarative shadow variable pair - createGraphNodes(variableReferenceGraph, entities, declarativeShadowVariableDescriptors, variableIdToUpdater, - declarativeShadowVariableToAliasMap); + var declarativeShadowVariableToAliasMap = createGraphNodes(variableReferenceGraphBuilder, entities, + declarativeShadowVariableDescriptors, variableIdToUpdater); // Create variable processors for each declarative shadow variable descriptor for (var declarativeShadowVariable : declarativeShadowVariableDescriptors) { - final var fromVariableId = declarativeShadowVariable.getVariableMetaModel(); - createSourceChangeProcessors(variableReferenceGraph, declarativeShadowVariable, fromVariableId); - createAliasToVariableChangeProcessors(variableReferenceGraph, declarativeShadowVariableToAliasMap, fromVariableId); + var fromVariableId = declarativeShadowVariable.getVariableMetaModel(); + createSourceChangeProcessors(variableReferenceGraphBuilder, declarativeShadowVariable, fromVariableId); + var aliasSet = declarativeShadowVariableToAliasMap.get(fromVariableId); + if (aliasSet != null) { + createAliasToVariableChangeProcessors(variableReferenceGraphBuilder, aliasSet, fromVariableId); + } } // Create the fixed edges in the graph - createFixedVariableRelationEdges(variableReferenceGraph, entities, declarativeShadowVariableDescriptors); - variableReferenceGraph.createGraph(graphCreator); + createFixedVariableRelationEdges(variableReferenceGraphBuilder, entities, declarativeShadowVariableDescriptors); + return variableReferenceGraphBuilder.build(graphCreator); } - private static void createGraphNodes(VariableReferenceGraph graph, Object[] entities, + private static Map, Set> createGraphNodes( + VariableReferenceGraphBuilder graph, Object[] entities, List> declarativeShadowVariableDescriptors, - Map, VariableUpdaterInfo> variableIdToUpdater, - Map, Set> declarativeShadowVariableToAliasMap) { + Map, VariableUpdaterInfo> variableIdToUpdater) { + var result = new HashMap, Set>(); for (var entity : entities) { for (var declarativeShadowVariableDescriptor : declarativeShadowVariableDescriptors) { var entityClass = declarativeShadowVariableDescriptor.getEntityDescriptor().getEntityClass(); if (entityClass.isInstance(entity)) { var variableId = declarativeShadowVariableDescriptor.getVariableMetaModel(); - var updater = variableIdToUpdater.computeIfAbsent(variableId, ignored -> new VariableUpdaterInfo( + var updater = variableIdToUpdater.computeIfAbsent(variableId, ignored -> new VariableUpdaterInfo<>( + variableId, declarativeShadowVariableDescriptor, declarativeShadowVariableDescriptor.getEntityDescriptor().getShadowVariableLoopedDescriptor(), declarativeShadowVariableDescriptor.getMemberAccessor(), declarativeShadowVariableDescriptor.getCalculator()::executeGetter)); - graph.addVariableReferenceEntity( - variableId, - entity, - updater); + graph.addVariableReferenceEntity(entity, updater); for (var sourceRoot : declarativeShadowVariableDescriptor.getSources()) { for (var source : sourceRoot.variableSourceReferences()) { if (source.downstreamDeclarativeVariableMetamodel() != null) { - declarativeShadowVariableToAliasMap - .computeIfAbsent(source.downstreamDeclarativeVariableMetamodel(), - ignored -> new LinkedHashSet<>()) + result.computeIfAbsent(source.downstreamDeclarativeVariableMetamodel(), + ignored -> new LinkedHashSet<>()) .add(source); } } @@ -90,9 +92,11 @@ private static void createGraphNodes(VariableReferenceGraph void createSourceChangeProcessors(VariableReferenceGraph variableReferenceGraph, + private static void createSourceChangeProcessors( + VariableReferenceGraphBuilder variableReferenceGraphBuilder, DeclarativeShadowVariableDescriptor declarativeShadowVariable, VariableMetaModel fromVariableId) { for (var source : declarativeShadowVariable.getSources()) { @@ -104,13 +108,11 @@ private static void createSourceChangeProcessors(VariableReferenceGr // non-declarative variables are not in the graph and must have their // own processor if (!sourcePart.isDeclarative()) { - variableReferenceGraph.addAfterProcessor(toVariableId, (graph, entity) -> { + variableReferenceGraphBuilder.addAfterProcessor(toVariableId, (graph, entity) -> { // Exploits the fact the source entity and the target entity must be the same, - // since non-declarative variables can only be accessed from the root entity - // i.e. paths like "otherVisit.previous" - // or "visitGroup[].otherVisit.previous" are not allowed, - // but paths like "previous" or - // "visitGroup[].previous" are. + // since non-declarative variables can only be accessed from the root entity; + // paths like "otherVisit.previous" or "visitGroup[].otherVisit.previous" are not allowed, + // but paths like "previous" or "visitGroup[].previous" are. // Without this invariant, an inverse set must be calculated // and maintained, // and this code is complicated enough. @@ -125,71 +127,78 @@ private static void createSourceChangeProcessors(VariableReferenceGr } private static void createAliasToVariableChangeProcessors( - VariableReferenceGraph variableReferenceGraph, - Map, Set> declarativeShadowVariableToAliasMap, + VariableReferenceGraphBuilder variableReferenceGraphBuilder, Set aliasSet, VariableMetaModel fromVariableId) { - for (var alias : declarativeShadowVariableToAliasMap.getOrDefault(fromVariableId, Collections.emptySet())) { + for (var alias : aliasSet) { var toVariableId = alias.targetVariableMetamodel(); var sourceVariableId = alias.variableMetaModel(); if (!alias.isDeclarative() && alias.affectGraphEdges()) { // Exploit the same fact as above - variableReferenceGraph.addBeforeProcessor(sourceVariableId, - (graph, toEntity) -> alias.targetEntityFunctionStartingFromVariableEntity() - .accept(toEntity, fromEntity -> { - // from/to can be null in extended models - // ex: previous is used as a source, but only an extended class - // has the to variable - var from = graph.lookupOrNull(fromVariableId, fromEntity); - if (from == null) { - return; - } - var to = graph.lookupOrNull(toVariableId, toEntity); - if (to == null) { - return; - } - graph.removeEdge(from, to); - })); - variableReferenceGraph.addAfterProcessor(sourceVariableId, - (graph, toEntity) -> alias.targetEntityFunctionStartingFromVariableEntity() - .accept(toEntity, fromEntity -> { - var from = graph.lookupOrNull(fromVariableId, fromEntity); - if (from == null) { - return; - } - var to = graph.lookupOrNull(toVariableId, toEntity); - if (to == null) { - return; - } - graph.addEdge(from, to); - })); + variableReferenceGraphBuilder.addBeforeProcessor(sourceVariableId, + (graph, toEntity) -> { + // from/to can be null in extended models + // ex: previous is used as a source, but only an extended class + // has the to variable + var to = graph.lookupOrNull(toVariableId, toEntity); + if (to == null) { + return; + } + var fromEntity = alias.targetEntityFunctionStartingFromVariableEntity() + .apply(toEntity); + if (fromEntity == null) { + return; + } + var from = graph.lookupOrNull(fromVariableId, fromEntity); + if (from == null) { + return; + } + graph.removeEdge(from, to); + }); + variableReferenceGraphBuilder.addAfterProcessor(sourceVariableId, + (graph, toEntity) -> { + var to = graph.lookupOrNull(toVariableId, toEntity); + if (to == null) { + return; + } + var fromEntity = alias.findTargetEntity(toEntity); + if (fromEntity == null) { + return; + } + var from = graph.lookupOrNull(fromVariableId, fromEntity); + if (from == null) { + return; + } + graph.addEdge(from, to); + }); } // Note: it is impossible to have a declarative variable affect graph edges, // since accessing a declarative variable from another declarative variable is prohibited. } } - private static void createFixedVariableRelationEdges(VariableReferenceGraph variableReferenceGraph, + private static void createFixedVariableRelationEdges( + VariableReferenceGraphBuilder variableReferenceGraphBuilder, Object[] entities, List> declarativeShadowVariableDescriptors) { for (var entity : entities) { for (var declarativeShadowVariableDescriptor : declarativeShadowVariableDescriptors) { var entityClass = declarativeShadowVariableDescriptor.getEntityDescriptor().getEntityClass(); - if (entityClass.isInstance(entity)) { - var toVariableId = declarativeShadowVariableDescriptor.getVariableMetaModel(); - for (var sourceRoot : declarativeShadowVariableDescriptor.getSources()) { - for (var source : sourceRoot.variableSourceReferences()) { - if (source.isTopLevel() && source.isDeclarative()) { - var fromVariableId = source.variableMetaModel(); - - sourceRoot.valueEntityFunction() - .accept(entity, fromEntity -> variableReferenceGraph.addFixedEdge( - variableReferenceGraph - .lookupOrError(fromVariableId, fromEntity), - variableReferenceGraph - .lookupOrError(toVariableId, entity))); - break; - } + if (!entityClass.isInstance(entity)) { + continue; + } + var toVariableId = declarativeShadowVariableDescriptor.getVariableMetaModel(); + var to = variableReferenceGraphBuilder.lookupOrError(toVariableId, entity); + for (var sourceRoot : declarativeShadowVariableDescriptor.getSources()) { + for (var source : sourceRoot.variableSourceReferences()) { + if (source.isTopLevel() && source.isDeclarative()) { + var fromVariableId = source.variableMetaModel(); + sourceRoot.valueEntityFunction() + .accept(entity, fromEntity -> { + var from = variableReferenceGraphBuilder.lookupOrError(fromVariableId, fromEntity); + variableReferenceGraphBuilder.addFixedEdge(from, to); + }); + break; } } } @@ -204,10 +213,8 @@ public DefaultShadowVariableSession forSolution(Solution_ solution) { } public DefaultShadowVariableSession forEntities(Object... entities) { - var variableReferenceGraph = new VariableReferenceGraph<>(ChangedVariableNotifier.of(scoreDirector)); - - visitGraph(solutionDescriptor, variableReferenceGraph, entities, graphCreator); - - return new DefaultShadowVariableSession<>(variableReferenceGraph); + var builder = new VariableReferenceGraphBuilder<>(ChangedVariableNotifier.of(scoreDirector)); + var graph = buildGraph(solutionDescriptor, builder, entities, graphCreator); + return new DefaultShadowVariableSession<>(graph); } } diff --git a/core/src/main/java/ai/timefold/solver/core/impl/domain/variable/declarative/DefaultTopologicalOrderGraph.java b/core/src/main/java/ai/timefold/solver/core/impl/domain/variable/declarative/DefaultTopologicalOrderGraph.java index 0416d957c46..1359d97baf7 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/domain/variable/declarative/DefaultTopologicalOrderGraph.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/domain/variable/declarative/DefaultTopologicalOrderGraph.java @@ -12,39 +12,49 @@ import ai.timefold.solver.core.impl.util.MutableInt; public class DefaultTopologicalOrderGraph implements TopologicalOrderGraph { - private final int[] ord; + + private final NodeTopologicalOrder[] nodeIdToTopologicalOrderMap; private final Map> componentMap; private final Set[] forwardEdges; private final Set[] backEdges; @SuppressWarnings({ "unchecked" }) public DefaultTopologicalOrderGraph(final int size) { - this.ord = new int[size]; + this.nodeIdToTopologicalOrderMap = new NodeTopologicalOrder[size]; this.componentMap = CollectionUtils.newLinkedHashMap(size); this.forwardEdges = new Set[size]; this.backEdges = new Set[size]; for (var i = 0; i < size; i++) { forwardEdges[i] = new HashSet<>(); backEdges[i] = new HashSet<>(); - ord[i] = i; + nodeIdToTopologicalOrderMap[i] = new NodeTopologicalOrder(i, i); } } @Override - public void addEdge(int from, int to) { - forwardEdges[from].add(to); - backEdges[to].add(from); + public void addEdge(int fromNode, int toNode) { + forwardEdges[fromNode].add(toNode); + backEdges[toNode].add(fromNode); + } + + @Override + public void removeEdge(int fromNode, int toNode) { + forwardEdges[fromNode].remove(toNode); + backEdges[toNode].remove(fromNode); } @Override - public void removeEdge(int from, int to) { - forwardEdges[from].remove(to); - backEdges[to].remove(from); + public void forEachEdge(EdgeConsumer edgeConsumer) { + for (var fromNode = 0; fromNode < forwardEdges.length; fromNode++) { + for (var toNode : forwardEdges[fromNode]) { + edgeConsumer.accept(fromNode, toNode); + } + } } @Override - public PrimitiveIterator.OfInt nodeForwardEdges(int from) { - return componentMap.get(from).stream() + public PrimitiveIterator.OfInt nodeForwardEdges(int fromNode) { + return componentMap.get(fromNode).stream() .flatMap(member -> forwardEdges[member].stream()) .mapToInt(Integer::intValue) .distinct().iterator(); @@ -73,12 +83,12 @@ public boolean isLooped(LoopedTracker loopedTracker, int node) { } @Override - public int getTopologicalOrder(int node) { - return ord[node]; + public NodeTopologicalOrder getTopologicalOrder(int node) { + return nodeIdToTopologicalOrderMap[node]; } @Override - public void endBatchChange() { + public void commitChanges() { var index = new MutableInt(1); var stackIndex = new MutableInt(0); var size = forwardEdges.length; @@ -100,7 +110,7 @@ public void endBatchChange() { var component = components.get(i); var componentNodes = new ArrayList(component.cardinality()); for (var node = component.nextSetBit(0); node >= 0; node = component.nextSetBit(node + 1)) { - ord[node] = ordIndex; + nodeIdToTopologicalOrderMap[node] = new NodeTopologicalOrder(node, ordIndex); componentNodes.add(node); componentMap.put(node, componentNodes); ordIndex++; diff --git a/core/src/main/java/ai/timefold/solver/core/impl/domain/variable/declarative/DefaultVariableReferenceGraph.java b/core/src/main/java/ai/timefold/solver/core/impl/domain/variable/declarative/DefaultVariableReferenceGraph.java new file mode 100644 index 00000000000..2b5fb00034c --- /dev/null +++ b/core/src/main/java/ai/timefold/solver/core/impl/domain/variable/declarative/DefaultVariableReferenceGraph.java @@ -0,0 +1,193 @@ +package ai.timefold.solver.core.impl.domain.variable.declarative; + +import java.util.ArrayList; +import java.util.BitSet; +import java.util.Collections; +import java.util.IdentityHashMap; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.function.BiConsumer; +import java.util.function.Consumer; +import java.util.function.IntFunction; +import java.util.stream.Collectors; + +import ai.timefold.solver.core.impl.util.DynamicIntArray; +import ai.timefold.solver.core.preview.api.domain.metamodel.VariableMetaModel; + +import org.jspecify.annotations.NonNull; +import org.jspecify.annotations.Nullable; + +final class DefaultVariableReferenceGraph implements VariableReferenceGraph { + + // These structures are immutable. + private final List> instanceList; + private final Map, Map>> variableReferenceToInstanceMap; + private final Map, List, Object>>> variableReferenceToBeforeProcessor; + private final Map, List, Object>>> variableReferenceToAfterProcessor; + + // These structures are mutable. + private final DynamicIntArray[] edgeCount; + private final TopologicalOrderGraph graph; + private final BitSet changed; + + private final Consumer affectedEntitiesUpdater; + + public DefaultVariableReferenceGraph(VariableReferenceGraphBuilder outerGraph, + IntFunction graphCreator) { + instanceList = List.copyOf(outerGraph.instanceList); + var instanceCount = instanceList.size(); + // Often the maps are a singleton; we improve performance by actually making it so. + variableReferenceToInstanceMap = mapOfMapsDeepCopyOf(outerGraph.variableReferenceToInstanceMap); + variableReferenceToBeforeProcessor = mapOfListsDeepCopyOf(outerGraph.variableReferenceToBeforeProcessor); + variableReferenceToAfterProcessor = mapOfListsDeepCopyOf(outerGraph.variableReferenceToAfterProcessor); + edgeCount = new DynamicIntArray[instanceCount]; + for (int i = 0; i < instanceCount; i++) { + edgeCount[i] = new DynamicIntArray(instanceCount); + } + graph = graphCreator.apply(instanceCount); + graph.withNodeData(instanceList); + changed = new BitSet(instanceCount); + + var entityToVariableReferenceMap = new IdentityHashMap>>(); + var visited = Collections.newSetFromMap(new IdentityHashMap<>()); + for (var instance : instanceList) { + var entity = instance.entity(); + if (visited.add(entity)) { + for (var variableId : outerGraph.variableReferenceToAfterProcessor.keySet()) { + afterVariableChanged(variableId, entity); + } + } + entityToVariableReferenceMap.computeIfAbsent(entity, ignored -> new ArrayList<>()) + .add(instance); + } + for (var fixedEdgeEntry : outerGraph.fixedEdges.entrySet()) { + for (var toEdge : fixedEdgeEntry.getValue()) { + addEdge(fixedEdgeEntry.getKey(), toEdge); + } + } + // Immutable optimized version of the map, now that it won't be updated anymore. + var immutableEntityToVariableReferenceMap = mapOfListsDeepCopyOf(entityToVariableReferenceMap); + // This mutable structure is created once, and reused from there on. + // Otherwise its internal collections were observed being re-created so often + // that the allocation of arrays would become a major bottleneck. + affectedEntitiesUpdater = new AffectedEntitiesUpdater<>(graph, instanceList, immutableEntityToVariableReferenceMap::get, + outerGraph.changedVariableNotifier); + } + + @Override + public @Nullable EntityVariablePair lookupOrNull(VariableMetaModel variableId, Object entity) { + var map = variableReferenceToInstanceMap.get(variableId); + if (map == null) { + return null; + } + return map.get(entity); + } + + @Override + public void addEdge(@NonNull EntityVariablePair from, @NonNull EntityVariablePair to) { + var fromNodeId = from.graphNodeId(); + var toNodeId = to.graphNodeId(); + if (fromNodeId == toNodeId) { + return; + } + + var count = edgeCount[fromNodeId].get(toNodeId); + if (count == 0) { + graph.addEdge(fromNodeId, toNodeId); + } + edgeCount[fromNodeId].set(toNodeId, count + 1); + markChanged(to); + } + + @Override + public void removeEdge(@NonNull EntityVariablePair from, @NonNull EntityVariablePair to) { + var fromNodeId = from.graphNodeId(); + var toNodeId = to.graphNodeId(); + if (fromNodeId == toNodeId) { + return; + } + + var count = edgeCount[fromNodeId].get(toNodeId); + if (count == 1) { + graph.removeEdge(fromNodeId, toNodeId); + } + edgeCount[fromNodeId].set(toNodeId, count - 1); + markChanged(to); + } + + @Override + public void markChanged(@NonNull EntityVariablePair node) { + changed.set(node.graphNodeId()); + } + + @Override + public void updateChanged() { + if (changed.isEmpty()) { + return; + } + graph.commitChanges(); + affectedEntitiesUpdater.accept(changed); + } + + @Override + public void beforeVariableChanged(VariableMetaModel variableReference, Object entity) { + if (variableReference.entity().type().isInstance(entity)) { + processEntity(variableReferenceToBeforeProcessor.getOrDefault(variableReference, Collections.emptyList()), entity); + } + } + + private void processEntity(List, Object>> processorList, Object entity) { + var processorCount = processorList.size(); + // Avoid creation of iterators on the hot path. + // The short-lived instances were observed to cause considerable GC pressure. + for (int i = 0; i < processorCount; i++) { + processorList.get(i).accept(this, entity); + } + } + + @Override + public void afterVariableChanged(VariableMetaModel variableReference, Object entity) { + if (variableReference.entity().type().isInstance(entity)) { + var node = lookupOrNull(variableReference, entity); + if (node != null) { + markChanged(node); + } + processEntity(variableReferenceToAfterProcessor.getOrDefault(variableReference, Collections.emptyList()), entity); + } + } + + @Override + public String toString() { + var edgeList = new LinkedHashMap, List>>(); + graph.forEachEdge((from, to) -> edgeList.computeIfAbsent(instanceList.get(from), k -> new ArrayList<>()) + .add(instanceList.get(to))); + return edgeList.entrySet() + .stream() + .map(e -> e.getKey() + "->" + e.getValue()) + .collect(Collectors.joining( + "," + System.lineSeparator() + " ", + "{" + System.lineSeparator() + " ", + "}")); + + } + + @SuppressWarnings("unchecked") + private static Map> mapOfMapsDeepCopyOf(Map> map) { + var entryArray = map.entrySet() + .stream() + .map(e -> Map.entry(e.getKey(), Map.copyOf(e.getValue()))) + .toArray(Map.Entry[]::new); + return Map.ofEntries(entryArray); + } + + @SuppressWarnings("unchecked") + private static Map> mapOfListsDeepCopyOf(Map> map) { + var entryArray = map.entrySet() + .stream() + .map(e -> Map.entry(e.getKey(), List.copyOf(e.getValue()))) + .toArray(Map.Entry[]::new); + return Map.ofEntries(entryArray); + } + +} diff --git a/core/src/main/java/ai/timefold/solver/core/impl/domain/variable/declarative/EmptyVariableReferenceGraph.java b/core/src/main/java/ai/timefold/solver/core/impl/domain/variable/declarative/EmptyVariableReferenceGraph.java new file mode 100644 index 00000000000..6cd6b2db24d --- /dev/null +++ b/core/src/main/java/ai/timefold/solver/core/impl/domain/variable/declarative/EmptyVariableReferenceGraph.java @@ -0,0 +1,53 @@ +package ai.timefold.solver.core.impl.domain.variable.declarative; + +import ai.timefold.solver.core.preview.api.domain.metamodel.VariableMetaModel; + +import org.jspecify.annotations.NonNull; +import org.jspecify.annotations.Nullable; + +final class EmptyVariableReferenceGraph implements VariableReferenceGraph { + + @SuppressWarnings("rawtypes") + public static final EmptyVariableReferenceGraph INSTANCE = new EmptyVariableReferenceGraph<>(); + + @Override + public @Nullable EntityVariablePair lookupOrNull(VariableMetaModel variableId, Object entity) { + return null; + } + + @Override + public void addEdge(@NonNull EntityVariablePair from, @NonNull EntityVariablePair to) { + throw new IllegalStateException("Impossible state: cannot modify an empty graph."); + } + + @Override + public void removeEdge(@NonNull EntityVariablePair from, @NonNull EntityVariablePair to) { + throw new IllegalStateException("Impossible state: cannot modify an empty graph."); + } + + @Override + public void markChanged(@NonNull EntityVariablePair node) { + throw new IllegalStateException("Impossible state: cannot modify an empty graph."); + } + + @Override + public void updateChanged() { + // No need to do anything. + } + + @Override + public void beforeVariableChanged(VariableMetaModel variableReference, Object entity) { + // No need to do anything. + } + + @Override + public void afterVariableChanged(VariableMetaModel variableReference, Object entity) { + // No need to do anything. + } + + @Override + public String toString() { + return "{}"; + } + +} diff --git a/core/src/main/java/ai/timefold/solver/core/impl/domain/variable/declarative/EntityVariablePair.java b/core/src/main/java/ai/timefold/solver/core/impl/domain/variable/declarative/EntityVariablePair.java index a3185f50264..da94bcd609c 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/domain/variable/declarative/EntityVariablePair.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/domain/variable/declarative/EntityVariablePair.java @@ -1,28 +1,23 @@ package ai.timefold.solver.core.impl.domain.variable.declarative; -import java.util.Objects; - -import ai.timefold.solver.core.preview.api.domain.metamodel.VariableMetaModel; - import org.jspecify.annotations.NullMarked; @NullMarked -public record EntityVariablePair(Object entity, VariableMetaModel variableId, - VariableUpdaterInfo variableReference, int graphNodeId) { +public record EntityVariablePair(Object entity, VariableUpdaterInfo variableReference, int graphNodeId) { @Override public boolean equals(Object object) { - if (!(object instanceof EntityVariablePair that)) + if (!(object instanceof EntityVariablePair that)) return false; return graphNodeId == that.graphNodeId; } @Override public int hashCode() { - return Objects.hashCode(graphNodeId); + return Integer.hashCode(graphNodeId); } @Override public String toString() { - return entity + ":" + variableId; + return entity + ":" + variableReference.id(); } } diff --git a/core/src/main/java/ai/timefold/solver/core/impl/domain/variable/declarative/LoopedTracker.java b/core/src/main/java/ai/timefold/solver/core/impl/domain/variable/declarative/LoopedTracker.java index 95a54c7ee46..7ea9ad0ab2e 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/domain/variable/declarative/LoopedTracker.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/domain/variable/declarative/LoopedTracker.java @@ -1,23 +1,42 @@ package ai.timefold.solver.core.impl.domain.variable.declarative; -import java.util.Arrays; +import java.util.BitSet; import org.jspecify.annotations.NullMarked; @NullMarked public final class LoopedTracker { - private final LoopedStatus[] statuses; - public LoopedTracker(int count) { - statuses = new LoopedStatus[count]; - Arrays.fill(statuses, LoopedStatus.UNKNOWN); + // Simple LoopedStatus[] array would have occupied too much memory with large node counts. + // Furthermore, allocating and/or clearing these large arrays is expensive as well. + private final BitSet present; + private final BitSet looped; + + public LoopedTracker(int nodeCount) { + this.present = new BitSet(nodeCount); + this.looped = new BitSet(nodeCount); } public void mark(int node, LoopedStatus status) { - statuses[node] = status; + if (status == LoopedStatus.UNKNOWN) { + present.clear(node); + looped.clear(node); + } else { + present.set(node); + looped.set(node, status == LoopedStatus.LOOPED); + } } public LoopedStatus status(int node) { - return statuses[node]; + if (present.isEmpty() || !present.get(node)) { + return LoopedStatus.UNKNOWN; + } + return looped.get(node) ? LoopedStatus.LOOPED : LoopedStatus.NOT_LOOPED; + } + + public void clear() { + present.clear(); + looped.clear(); } + } diff --git a/core/src/main/java/ai/timefold/solver/core/impl/domain/variable/declarative/RootVariableSource.java b/core/src/main/java/ai/timefold/solver/core/impl/domain/variable/declarative/RootVariableSource.java index 45bac10ebc3..5423081c497 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/domain/variable/declarative/RootVariableSource.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/domain/variable/declarative/RootVariableSource.java @@ -19,6 +19,7 @@ import ai.timefold.solver.core.preview.api.domain.variable.declarative.ShadowSources; import org.jspecify.annotations.NonNull; +import org.jspecify.annotations.Nullable; public record RootVariableSource( Class rootEntity, @@ -32,18 +33,18 @@ private record VariablePath(Class variableEntityClass, String variableName, List memberAccessorsBeforeEntity, List memberAccessorsAfterEntity) { - public BiConsumer> getValueVisitorFromVariableEntity() { - return (entity, consumer) -> { - var currentEntity = entity; - for (var member : memberAccessorsAfterEntity) { - currentEntity = member.executeGetter(currentEntity); - if (currentEntity == null) { - return; - } + + public @Nullable Object findTargetEntity(Object entity) { + var currentEntity = entity; + for (var member : memberAccessorsAfterEntity) { + currentEntity = member.executeGetter(currentEntity); + if (currentEntity == null) { + return null; } - consumer.accept(currentEntity); - }; + } + return currentEntity; } + } public static Iterator pathIterator(Class rootEntity, String path) { @@ -215,7 +216,7 @@ public static RootVariableSource from( isDeclarativeShadowVariable(variableMemberAccessor), solutionMetaModel.entity(rootEntityClass).variable(targetVariableName), downstreamDeclarativeVariable, - sourceVariablePath.getValueVisitorFromVariableEntity()); + sourceVariablePath::findTargetEntity); } private static void assertIsValidVariableReference(Class rootEntityClass, String variablePath, diff --git a/core/src/main/java/ai/timefold/solver/core/impl/domain/variable/declarative/TopologicalOrderGraph.java b/core/src/main/java/ai/timefold/solver/core/impl/domain/variable/declarative/TopologicalOrderGraph.java index d097dde51b0..d1c53e0dbce 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/domain/variable/declarative/TopologicalOrderGraph.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/domain/variable/declarative/TopologicalOrderGraph.java @@ -1,24 +1,15 @@ package ai.timefold.solver.core.impl.domain.variable.declarative; import java.util.List; -import java.util.PrimitiveIterator; -public interface TopologicalOrderGraph { - /** - * Called on the first edge modification of a batch. - */ - default void startBatchChange() { - } +public interface TopologicalOrderGraph extends BaseTopologicalOrderGraph { /** - * Called when all edge modifications are done. - * There is no prior {@link #startBatchChange()} call if - * no modifications were done. + * Called when all edge modifications are queued. * After this method returns, {@link #getTopologicalOrder(int)} * must be accurate for every node in the graph. */ - default void endBatchChange() { - } + void commitChanges(); /** * Called on graph creation to supply metadata about the graph nodes. @@ -26,11 +17,12 @@ default void endBatchChange() { * @param nodes A list of entity/variable pairs, where the nth item in the list * corresponds to the node with id n in the graph. */ - default void withNodeData(List nodes) { + default void withNodeData(List> nodes) { } /** * Called when a graph edge is added. + * The operation is added to a batch and only executed when {@link #commitChanges()} is called. *

* {@link #getTopologicalOrder(int)} is allowed to be invalid * when this method returns. @@ -39,43 +31,20 @@ default void withNodeData(List nodes) { /** * Called when a graph edge is removed. + * The operation is added to a batch and only executed when {@link #commitChanges()} is called. *

* {@link #getTopologicalOrder(int)} is allowed to be invalid * when this method returns. */ void removeEdge(int from, int to); - /** - * Return an iterator of the nodes that have the `from` node as a predecessor. - * - * @param from The predecessor node. - * @return an iterator of nodes with from as a predecessor. - */ - PrimitiveIterator.OfInt nodeForwardEdges(int from); + void forEachEdge(EdgeConsumer edgeConsumer); - /** - * Returns true is a given node is in a strongly connected component with a size - * greater than 1 (i.e. is in a loop) or is a transitive successor of a - * node with the above property. - * - * @param loopedTracker a tracker that can be used to record looped state to avoid - * recomputation. - * @param node The node being queried - * @return true if `node` is in a loop, false otherwise. - */ - boolean isLooped(LoopedTracker loopedTracker, int node); + @FunctionalInterface + interface EdgeConsumer { + + void accept(int from, int to); + + } - /** - * Returns a number corresponding to the topological order of a node. - * In particular, after {@link #endBatchChange()} is called, the following - * must be true for any pair of nodes A, B where: - *

    - *
  • A is a predecessor of B
  • - *
  • `isLooped(A) == isLooped(B) == false`
  • - *
- * getTopologicalOrder(A) < getTopologicalOrder(B) - *

- * Said number may not be unique. - */ - int getTopologicalOrder(int node); } \ No newline at end of file diff --git a/core/src/main/java/ai/timefold/solver/core/impl/domain/variable/declarative/VariableReferenceGraph.java b/core/src/main/java/ai/timefold/solver/core/impl/domain/variable/declarative/VariableReferenceGraph.java index 8bafade7928..0d12ff2e882 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/domain/variable/declarative/VariableReferenceGraph.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/domain/variable/declarative/VariableReferenceGraph.java @@ -1,356 +1,26 @@ package ai.timefold.solver.core.impl.domain.variable.declarative; -import java.util.ArrayList; -import java.util.BitSet; -import java.util.Collections; -import java.util.HashMap; -import java.util.IdentityHashMap; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.PriorityQueue; -import java.util.Set; -import java.util.function.BiConsumer; -import java.util.function.IntFunction; - -import ai.timefold.solver.core.impl.domain.variable.descriptor.VariableDescriptor; import ai.timefold.solver.core.preview.api.domain.metamodel.VariableMetaModel; import org.jspecify.annotations.NonNull; import org.jspecify.annotations.Nullable; -public class VariableReferenceGraph { - private final Map, Map> variableReferenceToInstanceMap; - private final List instanceList; - private final ChangedVariableNotifier changedVariableNotifier; - - private final Map, List, Object>>> variableReferenceToBeforeProcessor; - private final Map, List, Object>>> variableReferenceToAfterProcessor; - private final Map> fixedEdges; - private final IdentityHashMap> entityToVariableReferenceMap; - private int[][] counts; - private TopologicalOrderGraph graph; - private BitSet changed; - - public VariableReferenceGraph(ChangedVariableNotifier changedVariableNotifier) { - this.changedVariableNotifier = changedVariableNotifier; - variableReferenceToInstanceMap = new HashMap<>(); - instanceList = new ArrayList<>(); - variableReferenceToBeforeProcessor = new HashMap<>(); - variableReferenceToAfterProcessor = new HashMap<>(); - fixedEdges = new HashMap<>(); - entityToVariableReferenceMap = new IdentityHashMap<>(); - } - - public EntityVariablePair addVariableReferenceEntity( - VariableMetaModel variableId, - Entity_ entity, - VariableUpdaterInfo variableReference) { - if (variableReferenceToInstanceMap.containsKey(variableId) && - variableReferenceToInstanceMap.get(variableId).containsKey(entity)) { - return variableReferenceToInstanceMap.get(variableId).get(entity); - } - var node = new EntityVariablePair(entity, variableId, - variableReference, instanceList.size()); - variableReferenceToInstanceMap.computeIfAbsent(variableId, k -> new IdentityHashMap<>()) - .put(entity, node); - instanceList.add(node); - return node; - } - - public void addBeforeProcessor(VariableMetaModel variableId, - BiConsumer, Object> consumer) { - variableReferenceToBeforeProcessor.computeIfAbsent(variableId, k -> new ArrayList<>()) - .add(consumer); - } - - public void addAfterProcessor(VariableMetaModel variableId, - BiConsumer, Object> consumer) { - variableReferenceToAfterProcessor.computeIfAbsent(variableId, k -> new ArrayList<>()) - .add(consumer); - } - - public void createGraph(IntFunction graphCreator) { - counts = new int[instanceList.size()][instanceList.size()]; - graph = graphCreator.apply(instanceList.size()); - graph.withNodeData(instanceList); - changed = new BitSet(instanceList.size()); - - graph.startBatchChange(); - var visited = Collections.newSetFromMap(new IdentityHashMap<>()); - for (var instance : instanceList) { - if (visited.add(instance.entity())) { - for (var variableId : variableReferenceToAfterProcessor.keySet()) { - if (variableId.entity().type().isInstance(instance.entity())) { - afterVariableChanged(variableId, instance.entity()); - } - } - } - entityToVariableReferenceMap.computeIfAbsent(instance.entity(), ignored -> new ArrayList<>()) - .add(instance); - } - for (var fixedEdgeEntry : fixedEdges.entrySet()) { - for (var toEdge : fixedEdgeEntry.getValue()) { - addEdge(fixedEdgeEntry.getKey(), toEdge); - } - } - } - - public @Nullable EntityVariablePair lookupOrNull(VariableMetaModel variableId, Object entity) { - return variableReferenceToInstanceMap.getOrDefault(variableId, Collections.emptyMap()).get(entity); - } - - public @NonNull EntityVariablePair lookupOrError(VariableMetaModel variableId, Object entity) { - var out = lookupOrNull(variableId, entity); - if (out == null) { - throw new IllegalArgumentException(); - } - return out; - } - - public void addFixedEdge(@NonNull EntityVariablePair from, @NonNull EntityVariablePair to) { - if (from.graphNodeId() == to.graphNodeId()) { - return; - } - fixedEdges.computeIfAbsent(from, k -> new ArrayList<>()).add(to); - } - - public void addEdge(@NonNull EntityVariablePair from, @NonNull EntityVariablePair to) { - if (from.graphNodeId() == to.graphNodeId()) { - return; - } - if (changed.isEmpty()) { - graph.startBatchChange(); - } - var oldCount = counts[from.graphNodeId()][to.graphNodeId()]++; - if (oldCount == 0) { - graph.addEdge(from.graphNodeId(), to.graphNodeId()); - } - - markChanged(to); - } - - public void removeEdge(@NonNull EntityVariablePair from, @NonNull EntityVariablePair to) { - if (from.graphNodeId() == to.graphNodeId()) { - return; - } - if (changed.isEmpty()) { - graph.startBatchChange(); - } - var newCount = --counts[from.graphNodeId()][to.graphNodeId()]; - if (newCount == 0) { - graph.removeEdge(from.graphNodeId(), to.graphNodeId()); - } - markChanged(to); - } - - public void markChanged(@NonNull EntityVariablePair node) { - if (changed.isEmpty()) { - graph.startBatchChange(); - } - changed.set(node.graphNodeId()); - } - - record AffectedEntity(Object entity, VariableUpdaterInfo variableUpdaterInfo) { - @Override - public boolean equals(Object o) { - if (o instanceof AffectedEntity other) { - return entity == other.entity; - } - return false; - } - - @Override - public int hashCode() { - return System.identityHashCode(entity); - } - } - - public void updateChanged() { - if (changed.isEmpty()) { - return; - } - graph.endBatchChange(); - var visited = new boolean[instanceList.size()]; - var loopedTracker = new LoopedTracker(visited.length); - var affectedEntities = Collections.newSetFromMap(new IdentityHashMap()); - var nodeHeap = createInitialChangeQueue(); - - while (!nodeHeap.isEmpty()) { - var nextNode = nodeHeap.poll().nodeId; - if (visited[nextNode]) { - continue; - } - visited[nextNode] = true; - var shadowVariable = instanceList.get(nextNode); - var isChanged = updateShadowVariable(shadowVariable, - graph.isLooped(loopedTracker, nextNode), - affectedEntities); - - if (isChanged) { - graph.nodeForwardEdges(nextNode).forEachRemaining( - (int node) -> { - if (!visited[node]) { - nodeHeap.add(new AffectedShadowVariable(node, graph.getTopologicalOrder(node))); - } - }); - } - } - - updateInvalidityStatusOfAffectedEntities(affectedEntities, loopedTracker); - } - - @SuppressWarnings("unchecked") - private boolean updateShadowVariable(EntityVariablePair shadowVariable, - boolean isLooped, - Set affectedEntities) { - var isChanged = false; - var entity = shadowVariable.entity(); - var shadowVariableReference = shadowVariable.variableReference(); - var oldValue = shadowVariableReference.memberAccessor().executeGetter(entity); - - if (isLooped) { - // null might be a valid value, and thus it could be the case - // that is was not looped and null, then turned to looped and null, - // which is still considered a change. - isChanged = true; - affectedEntities.add(new AffectedEntity(entity, shadowVariableReference)); - if (oldValue != null) { - changedVariableNotifier.beforeVariableChanged().accept( - (VariableDescriptor) shadowVariableReference.variableDescriptor(), entity); - shadowVariableReference.memberAccessor().executeSetter(entity, null); - changedVariableNotifier.afterVariableChanged().accept( - (VariableDescriptor) shadowVariableReference.variableDescriptor(), entity); - } - } else { - var newValue = shadowVariableReference.calculator().apply(entity); - - if (!Objects.equals(oldValue, newValue)) { - affectedEntities.add(new AffectedEntity(entity, shadowVariableReference)); - changedVariableNotifier.beforeVariableChanged().accept( - (VariableDescriptor) shadowVariableReference.variableDescriptor(), entity); - shadowVariableReference.memberAccessor().executeSetter(entity, newValue); - changedVariableNotifier.afterVariableChanged().accept( - (VariableDescriptor) shadowVariableReference.variableDescriptor(), entity); - isChanged = true; - } - } - return isChanged; - } +public sealed interface VariableReferenceGraph + permits DefaultVariableReferenceGraph, EmptyVariableReferenceGraph { - record AffectedShadowVariable(int nodeId, int topologicalIndex) implements Comparable { - @Override - public int compareTo(AffectedShadowVariable heapItem) { - return topologicalIndex - heapItem.topologicalIndex; - } + @Nullable + EntityVariablePair lookupOrNull(VariableMetaModel variableId, Object entity); - @Override - public boolean equals(Object o) { - if (o instanceof AffectedShadowVariable other) { - return nodeId == other.nodeId; - } - return false; - } + void addEdge(@NonNull EntityVariablePair from, @NonNull EntityVariablePair to); - @Override - public int hashCode() { - return nodeId; - } - } + void removeEdge(@NonNull EntityVariablePair from, @NonNull EntityVariablePair to); - private PriorityQueue createInitialChangeQueue() { - var heap = new PriorityQueue(instanceList.size()); - // BitSet iteration: get the first set bit at or after 0, - // then get the first set bit after that bit. - // Iteration ends when nextSetBit returns -1. - // This has the potential to overflow, since to do the - // test, we necessarily need to do nextSetBit(i + 1), - // and i + 1 can be negative if Integer.MAX_VALUE is set - // in the BitSet. - // This should never happen, since arrays in Java are limited - // to slightly less than Integer.MAX_VALUE. - for (var i = changed.nextSetBit(0); i >= 0; i = changed.nextSetBit(i + 1)) { - var topologicalOrder = graph.getTopologicalOrder(i); - heap.add(new AffectedShadowVariable(i, topologicalOrder)); - if (i == Integer.MAX_VALUE) { - break; // or (i+1) would overflow - } - } - changed.clear(); - return heap; - } + void markChanged(@NonNull EntityVariablePair node); - @SuppressWarnings("unchecked") - private void updateInvalidityStatusOfAffectedEntities(Set affectedEntities, LoopedTracker loopedTracker) { - for (var affectedEntity : affectedEntities) { - var shadowVariableLoopedDescriptor = affectedEntity.variableUpdaterInfo.shadowVariableLoopedDescriptor(); - if (shadowVariableLoopedDescriptor == null) { - continue; - } - var entity = affectedEntity.entity; - var isEntityLooped = false; - for (var node : entityToVariableReferenceMap.get(entity)) { - if (graph.isLooped(loopedTracker, node.graphNodeId())) { - isEntityLooped = true; - break; - } - } - var oldValue = shadowVariableLoopedDescriptor.getValue(entity); - if (!Objects.equals(oldValue, isEntityLooped)) { - changedVariableNotifier.beforeVariableChanged().accept( - (VariableDescriptor) shadowVariableLoopedDescriptor, - entity); - shadowVariableLoopedDescriptor.setValue(entity, isEntityLooped); - changedVariableNotifier.afterVariableChanged().accept( - (VariableDescriptor) shadowVariableLoopedDescriptor, - entity); - } - } - } + void updateChanged(); - public void beforeVariableChanged(VariableMetaModel variableReference, Object entity) { - if (variableReference.entity().type().isInstance(entity)) { - var updaterList = variableReferenceToBeforeProcessor.getOrDefault(variableReference, Collections.emptyList()); - for (var consumer : updaterList) { - consumer.accept(this, entity); - } - } - } + void beforeVariableChanged(VariableMetaModel variableReference, Object entity); - public void afterVariableChanged(VariableMetaModel variableReference, Object entity) { - if (variableReference.entity().type().isInstance(entity)) { - var updaterList = variableReferenceToAfterProcessor.getOrDefault(variableReference, Collections.emptyList()); - var node = lookupOrNull(variableReference, entity); - if (node != null) { - markChanged(node); - } - for (var consumer : updaterList) { - consumer.accept(this, entity); - } - } - } + void afterVariableChanged(VariableMetaModel variableReference, Object entity); - @Override - public String toString() { - var builder = new StringBuilder("{\n"); - for (int from = 0; from < counts.length; from++) { - var first = true; - for (int to = 0; to < counts.length; to++) { - if (counts[from][to] != 0) { - if (first) { - first = false; - builder.append(" \"").append(instanceList.get(from)).append("\": ["); - } else { - builder.append(", "); - } - builder.append("\"%s\"".formatted(instanceList.get(to))); - } - } - if (!first) { - builder.append("],\n"); - } - } - builder.append("}"); - return builder.toString(); - } } diff --git a/core/src/main/java/ai/timefold/solver/core/impl/domain/variable/declarative/VariableReferenceGraphBuilder.java b/core/src/main/java/ai/timefold/solver/core/impl/domain/variable/declarative/VariableReferenceGraphBuilder.java new file mode 100644 index 00000000000..2e635c5bde4 --- /dev/null +++ b/core/src/main/java/ai/timefold/solver/core/impl/domain/variable/declarative/VariableReferenceGraphBuilder.java @@ -0,0 +1,87 @@ +package ai.timefold.solver.core.impl.domain.variable.declarative; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.IdentityHashMap; +import java.util.List; +import java.util.Map; +import java.util.function.BiConsumer; +import java.util.function.IntFunction; + +import ai.timefold.solver.core.preview.api.domain.metamodel.VariableMetaModel; + +import org.jspecify.annotations.NonNull; + +public final class VariableReferenceGraphBuilder { + + final ChangedVariableNotifier changedVariableNotifier; + final Map, List, Object>>> variableReferenceToBeforeProcessor; + final Map, List, Object>>> variableReferenceToAfterProcessor; + final List> instanceList; + final Map, List>> fixedEdges; + final Map, Map>> variableReferenceToInstanceMap; + + public VariableReferenceGraphBuilder(ChangedVariableNotifier changedVariableNotifier) { + this.changedVariableNotifier = changedVariableNotifier; + instanceList = new ArrayList<>(); + variableReferenceToInstanceMap = new HashMap<>(); + variableReferenceToBeforeProcessor = new HashMap<>(); + variableReferenceToAfterProcessor = new HashMap<>(); + fixedEdges = new HashMap<>(); + } + + public void addVariableReferenceEntity(Entity_ entity, VariableUpdaterInfo variableReference) { + var variableId = variableReference.id(); + var instanceMap = variableReferenceToInstanceMap.get(variableId); + var instance = instanceMap == null ? null : instanceMap.get(entity); + if (instance != null) { + return; + } + if (instanceMap == null) { + instanceMap = new IdentityHashMap<>(); + variableReferenceToInstanceMap.put(variableId, instanceMap); + } + var node = new EntityVariablePair<>(entity, variableReference, instanceList.size()); + instanceMap.put(entity, node); + instanceList.add(node); + } + + public void addFixedEdge(@NonNull EntityVariablePair from, @NonNull EntityVariablePair to) { + if (from.graphNodeId() == to.graphNodeId()) { + return; + } + fixedEdges.computeIfAbsent(from, k -> new ArrayList<>()).add(to); + } + + public void addBeforeProcessor(VariableMetaModel variableId, + BiConsumer, Object> consumer) { + variableReferenceToBeforeProcessor.computeIfAbsent(variableId, k -> new ArrayList<>()) + .add(consumer); + } + + public void addAfterProcessor(VariableMetaModel variableId, + BiConsumer, Object> consumer) { + variableReferenceToAfterProcessor.computeIfAbsent(variableId, k -> new ArrayList<>()) + .add(consumer); + } + + @SuppressWarnings("unchecked") + public VariableReferenceGraph build(IntFunction graphCreator) { + // TODO empty shows up in VRP example when using it as CVRP, not CVRPTW + // In that case, TimeWindowedCustomer does not exist + // and therefore Customer has no shadow variable. + // Surely there has to be an earlier way to catch this? + return instanceList.isEmpty() ? EmptyVariableReferenceGraph.INSTANCE + : new DefaultVariableReferenceGraph<>(this, graphCreator); + } + + public @NonNull EntityVariablePair lookupOrError(VariableMetaModel variableId, Object entity) { + var out = variableReferenceToInstanceMap.getOrDefault(variableId, Collections.emptyMap()).get(entity); + if (out == null) { + throw new IllegalArgumentException(); + } + return out; + } + +} diff --git a/core/src/main/java/ai/timefold/solver/core/impl/domain/variable/declarative/VariableSourceReference.java b/core/src/main/java/ai/timefold/solver/core/impl/domain/variable/declarative/VariableSourceReference.java index 7a78519d693..88311a3ed93 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/domain/variable/declarative/VariableSourceReference.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/domain/variable/declarative/VariableSourceReference.java @@ -1,8 +1,7 @@ package ai.timefold.solver.core.impl.domain.variable.declarative; import java.util.List; -import java.util.function.BiConsumer; -import java.util.function.Consumer; +import java.util.function.Function; import ai.timefold.solver.core.impl.domain.common.accessor.MemberAccessor; import ai.timefold.solver.core.preview.api.domain.metamodel.VariableMetaModel; @@ -18,8 +17,14 @@ public record VariableSourceReference(VariableMetaModel variableMetaMod boolean isDeclarative, VariableMetaModel targetVariableMetamodel, @Nullable VariableMetaModel downstreamDeclarativeVariableMetamodel, - BiConsumer> targetEntityFunctionStartingFromVariableEntity) { + Function targetEntityFunctionStartingFromVariableEntity) { + public boolean affectGraphEdges() { return downstreamDeclarativeVariableMetamodel != null; } + + public @Nullable Object findTargetEntity(Object entity) { + return targetEntityFunctionStartingFromVariableEntity.apply(entity); + } + } diff --git a/core/src/main/java/ai/timefold/solver/core/impl/domain/variable/declarative/VariableUpdaterInfo.java b/core/src/main/java/ai/timefold/solver/core/impl/domain/variable/declarative/VariableUpdaterInfo.java index 196f00a1597..b0447948ef7 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/domain/variable/declarative/VariableUpdaterInfo.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/domain/variable/declarative/VariableUpdaterInfo.java @@ -3,14 +3,16 @@ import java.util.function.Function; import ai.timefold.solver.core.impl.domain.common.accessor.MemberAccessor; +import ai.timefold.solver.core.preview.api.domain.metamodel.VariableMetaModel; import org.jspecify.annotations.NullMarked; import org.jspecify.annotations.Nullable; @NullMarked -public record VariableUpdaterInfo( - DeclarativeShadowVariableDescriptor variableDescriptor, - @Nullable ShadowVariableLoopedVariableDescriptor shadowVariableLoopedDescriptor, +public record VariableUpdaterInfo( + VariableMetaModel id, + DeclarativeShadowVariableDescriptor variableDescriptor, + @Nullable ShadowVariableLoopedVariableDescriptor shadowVariableLoopedDescriptor, MemberAccessor memberAccessor, Function calculator) { } diff --git a/core/src/main/java/ai/timefold/solver/core/impl/util/DynamicIntArray.java b/core/src/main/java/ai/timefold/solver/core/impl/util/DynamicIntArray.java new file mode 100644 index 00000000000..ca89231a9c6 --- /dev/null +++ b/core/src/main/java/ai/timefold/solver/core/impl/util/DynamicIntArray.java @@ -0,0 +1,207 @@ +package ai.timefold.solver.core.impl.util; + +import java.util.Arrays; + +/** + * A class representing an int array that is dynamically allocated based on the first set index. + * The array is only created when the first element is set and is reallocated as needed + * when lower indices are accessed. + */ +public final class DynamicIntArray { + + // Growth factor for array expansion; not too much, the point of this class is to avoid excessive memory use. + private static final double GROWTH_FACTOR = 1.2; + // Minimum capacity increment to avoid small incremental growth + private static final int MIN_CAPACITY_INCREMENT = 10; + + private final int maxLength; + private int[] array; + private int firstIndex; + private int lastIndex; + + public DynamicIntArray() { + this(Integer.MAX_VALUE); + } + + /** + * Creates a new empty DynamicIntArray. + */ + public DynamicIntArray(int maxLength) { + this.maxLength = maxLength; + // Array is null until first element is set + this.array = null; + // Initialize with invalid indices + this.firstIndex = Integer.MAX_VALUE; + this.lastIndex = Integer.MIN_VALUE; + } + + /** + * Sets the value at the specified index. + * If this is the first element, the array is created. + * If the index is lower than the current firstIndex or higher than the current lastIndex, + * the array is reallocated with a growth strategy to reduce frequent reallocations. + * + * @param index the index at which to set the value + * @param value the value to set + */ + public void set(int index, int value) { + if (index < 0 || index >= maxLength) { + throw new ArrayIndexOutOfBoundsException(index); + } + if (array == null) { + // First element, create the array with initial capacity + var initialCapacity = Math.min(MIN_CAPACITY_INCREMENT, maxLength); + array = new int[initialCapacity]; + firstIndex = index; + lastIndex = index; + array[0] = value; + } else if (index < firstIndex) { + // New index is lower than first index, need to reallocate + var currentSize = lastIndex - firstIndex + 1; + var offset = firstIndex - index; + + // Calculate new capacity with growth strategy + var requiredCapacity = currentSize + offset; + var newCapacity = calculateNewCapacity(requiredCapacity); + + // Copy existing elements to new array with offset + var newArray = new int[newCapacity]; + System.arraycopy(array, 0, newArray, offset, currentSize); + array = newArray; + firstIndex = index; + array[0] = value; + } else if (index > lastIndex) { + // New index is higher than last index, need to expand + var currentSize = lastIndex - firstIndex + 1; + var newSize = index - firstIndex + 1; + + if (newSize > array.length) { + // Calculate new capacity with growth strategy + var newCapacity = calculateNewCapacity(newSize); + + // Copy existing elements to new array + var newArray = new int[newCapacity]; + System.arraycopy(array, 0, newArray, 0, currentSize); + array = newArray; + } + + // Update last index + lastIndex = index; + array[index - firstIndex] = value; + } else { + // Index is within existing range + array[index - firstIndex] = value; + } + } + + /** + * Calculates the new capacity based on the required capacity and growth strategy. + * + * @param requiredCapacity the minimum capacity needed + * @return the new capacity + */ + private int calculateNewCapacity(int requiredCapacity) { + var currentCapacity = array != null ? array.length : 0; + + if (requiredCapacity <= currentCapacity) { + return currentCapacity; + } + + // Calculate new capacity using growth factor + var newCapacity = (int) (currentCapacity * GROWTH_FACTOR); + + // Ensure minimum increment + if (newCapacity - currentCapacity < MIN_CAPACITY_INCREMENT) { + newCapacity = currentCapacity + MIN_CAPACITY_INCREMENT; + } + + // Ensure new capacity is at least the required capacity + if (newCapacity < requiredCapacity) { + newCapacity = requiredCapacity; + } + + // Ensure new capacity doesn't exceed maxLength + return Math.min(newCapacity, maxLength); + } + + /** + * Gets the value at the specified index. + * + * @param index the index from which to get the value + * @return the value at the index + * @throws IndexOutOfBoundsException if the index is out of bounds + */ + public int get(int index) { + if (index < 0 || index >= maxLength) { + throw new ArrayIndexOutOfBoundsException(index); + } + if (array == null || index < firstIndex || index > lastIndex) { + return 0; + } + return array[index - firstIndex]; + } + + /** + * Checks if the array contains the specified index. + * + * @param index the index to check + * @return true if the index is within bounds, false otherwise + */ + boolean containsIndex(int index) { + return array != null && index >= firstIndex && index <= lastIndex; + } + + /** + * Gets the first index of the array. + * + * @return the first index + * @throws IllegalStateException if the array is empty + */ + int getFirstIndex() { + if (array == null) { + throw new IllegalStateException("Array is empty"); + } + return firstIndex; + } + + /** + * Gets the last index of the array. + * + * @return the last index + * @throws IllegalStateException if the array is empty + */ + int getLastIndex() { + if (array == null) { + throw new IllegalStateException("Array is empty"); + } + return lastIndex; + } + + /** + * Gets the length of the array. + * + * @return the length of the array, or 0 if the array is empty + */ + int length() { + if (array == null) { + return 0; + } + return lastIndex + 1; + } + + /** + * Clears the array by setting all values to 0. + * The array structure is preserved, only the values are reset. + */ + public void clear() { + // If array is null, there's nothing to clear + if (array == null) { + return; + } + + // Only clear the used portion of the array (from firstIndex to lastIndex) + // This is more efficient for large arrays with sparse indices + Arrays.fill(array, 0, lastIndex - firstIndex + 1, 0); + } + +} diff --git a/core/src/main/java/ai/timefold/solver/core/impl/util/ListBasedScalingOrderedSet.java b/core/src/main/java/ai/timefold/solver/core/impl/util/ListBasedScalingOrderedSet.java index 2ce3bf840fb..30d63989f06 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/util/ListBasedScalingOrderedSet.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/util/ListBasedScalingOrderedSet.java @@ -131,16 +131,16 @@ public boolean remove(Object o) { if (belowThreshold) { return list.remove(o); } else { - int newSize = set.size() - 1; + if (!set.remove(o)) { + return false; + } + int newSize = set.size(); if (newSize <= LIST_SIZE_THRESHOLD) { - set.remove(o); list = new ArrayList<>(set); set = null; belowThreshold = true; - return true; - } else { - return set.remove(o); } + return true; } } @@ -165,4 +165,9 @@ public void clear() { } } + @Override + public String toString() { + return belowThreshold ? list.toString() : set.toString(); + } + } diff --git a/core/src/test/java/ai/timefold/solver/core/impl/domain/variable/declarative/RootVariableSourceTest.java b/core/src/test/java/ai/timefold/solver/core/impl/domain/variable/declarative/RootVariableSourceTest.java index d3adb9c7e0a..443e3a5f66a 100644 --- a/core/src/test/java/ai/timefold/solver/core/impl/domain/variable/declarative/RootVariableSourceTest.java +++ b/core/src/test/java/ai/timefold/solver/core/impl/domain/variable/declarative/RootVariableSourceTest.java @@ -85,12 +85,10 @@ void pathUsingBuiltinShadow() { assertThat(source.targetVariableMetamodel()).isEqualTo(shadowVariableMetaModel); assertThat(source.downstreamDeclarativeVariableMetamodel()).isNull(); - var sourceVisitor = mock(Consumer.class); var entity = new TestdataInvalidDeclarativeValue("v1"); - source.targetEntityFunctionStartingFromVariableEntity().accept(entity, sourceVisitor); + var result = source.targetEntityFunctionStartingFromVariableEntity().apply(entity); - verify(sourceVisitor).accept(entity); - verifyNoMoreInteractions(sourceVisitor); + assertThat(result).isSameAs(entity); var rootVisitor = mock(Consumer.class); rootVariableSource.valueEntityFunction().accept(entity, rootVisitor); @@ -120,12 +118,9 @@ void pathUsingDeclarativeShadow() { assertThat(source.targetVariableMetamodel()).isEqualTo(shadowVariableMetaModel); assertThat(source.downstreamDeclarativeVariableMetamodel()).isEqualTo(dependencyMetaModel); - var sourceVisitor = mock(Consumer.class); var entity = new TestdataInvalidDeclarativeValue("v1"); - source.targetEntityFunctionStartingFromVariableEntity().accept(entity, sourceVisitor); - - verify(sourceVisitor).accept(entity); - verifyNoMoreInteractions(sourceVisitor); + var result = source.targetEntityFunctionStartingFromVariableEntity().apply(entity); + assertThat(result).isSameAs(entity); var rootVisitor = mock(Consumer.class); rootVariableSource.valueEntityFunction().accept(entity, rootVisitor); @@ -155,16 +150,13 @@ void pathUsingDeclarativeShadowAfterGroup() { assertThat(source.targetVariableMetamodel()).isEqualTo(shadowVariableMetaModel); assertThat(source.downstreamDeclarativeVariableMetamodel()).isEqualTo(dependencyMetaModel); - var sourceVisitor = mock(Consumer.class); var group = new TestdataInvalidDeclarativeValue("group"); var v1 = new TestdataInvalidDeclarativeValue("v1"); var v2 = new TestdataInvalidDeclarativeValue("v2"); group.setGroup(List.of(v1, v2)); - source.targetEntityFunctionStartingFromVariableEntity().accept(group, sourceVisitor); - - verify(sourceVisitor).accept(group); - verifyNoMoreInteractions(sourceVisitor); + var result = source.targetEntityFunctionStartingFromVariableEntity().apply(group); + assertThat(result).isSameAs(group); var rootVisitor = mock(Consumer.class); rootVariableSource.valueEntityFunction().accept(group, rootVisitor); @@ -195,16 +187,13 @@ void pathUsingBuiltinShadowAfterGroup() { assertThat(source.targetVariableMetamodel()).isEqualTo(shadowVariableMetaModel); assertThat(source.downstreamDeclarativeVariableMetamodel()).isNull(); - var sourceVisitor = mock(Consumer.class); var group = new TestdataInvalidDeclarativeValue("group"); var v1 = new TestdataInvalidDeclarativeValue("v1"); var v2 = new TestdataInvalidDeclarativeValue("v2"); group.setGroup(List.of(v1, v2)); - source.targetEntityFunctionStartingFromVariableEntity().accept(group, sourceVisitor); - - verify(sourceVisitor).accept(group); - verifyNoMoreInteractions(sourceVisitor); + var result = source.targetEntityFunctionStartingFromVariableEntity().apply(group); + assertThat(result).isSameAs(group); var rootVisitor = mock(Consumer.class); rootVariableSource.valueEntityFunction().accept(group, rootVisitor); @@ -235,7 +224,6 @@ void pathUsingDeclarativeShadowAfterGroupAfterFact() { assertThat(source.targetVariableMetamodel()).isEqualTo(shadowVariableMetaModel); assertThat(source.downstreamDeclarativeVariableMetamodel()).isEqualTo(dependencyMetaModel); - var sourceVisitor = mock(Consumer.class); var root = new TestdataInvalidDeclarativeValue("fact"); var fact = new TestdataInvalidDeclarativeValue("fact"); var v1 = new TestdataInvalidDeclarativeValue("v1"); @@ -243,10 +231,8 @@ void pathUsingDeclarativeShadowAfterGroupAfterFact() { root.setFact(fact); fact.setGroup(List.of(v1, v2)); - source.targetEntityFunctionStartingFromVariableEntity().accept(root, sourceVisitor); - - verify(sourceVisitor).accept(root); - verifyNoMoreInteractions(sourceVisitor); + var result = source.targetEntityFunctionStartingFromVariableEntity().apply(root); + assertThat(result).isSameAs(root); var rootVisitor = mock(Consumer.class); rootVariableSource.valueEntityFunction().accept(root, rootVisitor); @@ -286,15 +272,12 @@ void pathUsingDeclarativeShadowAfterBuiltinShadow() { assertThat(dependencySource.targetVariableMetamodel()).isEqualTo(shadowVariableMetaModel); assertThat(dependencySource.downstreamDeclarativeVariableMetamodel()).isEqualTo(dependencyMetaModel); - var sourceVisitor = mock(Consumer.class); var previousElement = new TestdataInvalidDeclarativeValue("previous"); var currentElement = new TestdataInvalidDeclarativeValue("current"); currentElement.setPrevious(previousElement); - previousSource.targetEntityFunctionStartingFromVariableEntity().accept(currentElement, sourceVisitor); - - verify(sourceVisitor).accept(previousElement); - verifyNoMoreInteractions(sourceVisitor); + var result = previousSource.targetEntityFunctionStartingFromVariableEntity().apply(currentElement); + assertThat(result).isSameAs(previousElement); var rootVisitor = mock(Consumer.class); rootVariableSource.valueEntityFunction().accept(currentElement, rootVisitor); @@ -333,7 +316,6 @@ void pathUsingDeclarativeShadowAfterBuiltinShadowAfterGroup() { assertThat(dependencySource.targetVariableMetamodel()).isEqualTo(shadowVariableMetaModel); assertThat(dependencySource.downstreamDeclarativeVariableMetamodel()).isEqualTo(dependencyMetaModel); - var sourceVisitor = mock(Consumer.class); var previousElement = new TestdataInvalidDeclarativeValue("previous"); var currentElement = new TestdataInvalidDeclarativeValue("current"); var group = new TestdataInvalidDeclarativeValue("group"); @@ -341,10 +323,8 @@ void pathUsingDeclarativeShadowAfterBuiltinShadowAfterGroup() { currentElement.setPrevious(previousElement); group.setGroup(List.of(currentElement)); - previousSource.targetEntityFunctionStartingFromVariableEntity().accept(currentElement, sourceVisitor); - - verify(sourceVisitor).accept(previousElement); - verifyNoMoreInteractions(sourceVisitor); + var result = previousSource.targetEntityFunctionStartingFromVariableEntity().apply(currentElement); + assertThat(result).isSameAs(previousElement); var rootVisitor = mock(Consumer.class); rootVariableSource.valueEntityFunction().accept(group, rootVisitor); diff --git a/core/src/test/java/ai/timefold/solver/core/impl/domain/variable/listener/support/VariableListenerSupportTest.java b/core/src/test/java/ai/timefold/solver/core/impl/domain/variable/listener/support/VariableListenerSupportTest.java index 5efdd8bdaae..2af87cb6775 100644 --- a/core/src/test/java/ai/timefold/solver/core/impl/domain/variable/listener/support/VariableListenerSupportTest.java +++ b/core/src/test/java/ai/timefold/solver/core/impl/domain/variable/listener/support/VariableListenerSupportTest.java @@ -173,9 +173,11 @@ public MockTopologicalOrderGraph(int size) { } @Override - public void withNodeData(List nodes) { + public void withNodeData(List> nodes) { nodeToEntities = nodes.stream().map(EntityVariablePair::entity).toArray(Object[]::new); - nodeToVariableMetamodel = nodes.stream().map(EntityVariablePair::variableId).toArray(VariableMetaModel[]::new); + nodeToVariableMetamodel = nodes.stream() + .map(e -> e.variableReference().id()) + .toArray(VariableMetaModel[]::new); } public void addEdge(VariableMetaModel fromId, Object fromEntity, VariableMetaModel toId, @@ -189,15 +191,17 @@ public void removeEdge(VariableMetaModel fromId, Object fromEntity, Var } @Override - public void addEdge(int from, int to) { - super.addEdge(from, to); - addEdge(nodeToVariableMetamodel[from], nodeToEntities[from], nodeToVariableMetamodel[to], nodeToEntities[to]); + public void addEdge(int fromNode, int toNode) { + super.addEdge(fromNode, toNode); + addEdge(nodeToVariableMetamodel[fromNode], nodeToEntities[fromNode], nodeToVariableMetamodel[toNode], + nodeToEntities[toNode]); } @Override - public void removeEdge(int from, int to) { - super.addEdge(from, to); - removeEdge(nodeToVariableMetamodel[from], nodeToEntities[from], nodeToVariableMetamodel[to], nodeToEntities[to]); + public void removeEdge(int fromNode, int toNode) { + super.removeEdge(fromNode, toNode); + removeEdge(nodeToVariableMetamodel[fromNode], nodeToEntities[fromNode], nodeToVariableMetamodel[toNode], + nodeToEntities[toNode]); } } diff --git a/core/src/test/java/ai/timefold/solver/core/impl/util/DynamicIntArrayTest.java b/core/src/test/java/ai/timefold/solver/core/impl/util/DynamicIntArrayTest.java new file mode 100644 index 00000000000..47cf49b05b7 --- /dev/null +++ b/core/src/test/java/ai/timefold/solver/core/impl/util/DynamicIntArrayTest.java @@ -0,0 +1,396 @@ +package ai.timefold.solver.core.impl.util; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; + +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +class DynamicIntArrayTest { + + @Nested + @DisplayName("Constructor tests") + class ConstructorTests { + + @Test + @DisplayName("Default constructor initializes with max size Integer.MAX_VALUE") + void defaultConstructor() { + var array = new DynamicIntArray(); + + assertThatExceptionOfType(IllegalStateException.class) + .isThrownBy(array::getFirstIndex) + .withMessage("Array is empty"); + + assertThatExceptionOfType(IllegalStateException.class) + .isThrownBy(array::getLastIndex) + .withMessage("Array is empty"); + + assertThat(array.length()).isZero(); + assertThat(array.containsIndex(0)).isFalse(); + assertThat(array.get(0)).isZero(); + } + + @Test + @DisplayName("Constructor with maxSize initializes correctly") + void constructorWithMaxSize() { + var array = new DynamicIntArray(100); + + assertThatExceptionOfType(IllegalStateException.class) + .isThrownBy(array::getFirstIndex); + + assertThat(array.length()).isZero(); + + // Test bound checking with maxSize + assertThatExceptionOfType(ArrayIndexOutOfBoundsException.class) + .isThrownBy(() -> array.set(100, 5)); + } + } + + @Nested + @DisplayName("Set method tests") + class SetMethodTests { + + @Test + @DisplayName("Set first element initializes the array") + void setFirstElement() { + var array = new DynamicIntArray(); + array.set(10, 42); + + assertThat(array.get(10)).isEqualTo(42); + assertThat(array.getFirstIndex()).isEqualTo(10); + assertThat(array.getLastIndex()).isEqualTo(10); + assertThat(array.containsIndex(10)).isTrue(); + assertThat(array.containsIndex(9)).isFalse(); + assertThat(array.containsIndex(11)).isFalse(); + assertThat(array.length()).isEqualTo(11); // 0-10 inclusive + } + + @Test + @DisplayName("Set lower index than first index reallocates the array") + void setLowerIndex() { + var array = new DynamicIntArray(); + array.set(10, 42); + array.set(5, 24); + + assertThat(array.get(5)).isEqualTo(24); + assertThat(array.get(10)).isEqualTo(42); + assertThat(array.getFirstIndex()).isEqualTo(5); + assertThat(array.getLastIndex()).isEqualTo(10); + assertThat(array.containsIndex(5)).isTrue(); + assertThat(array.containsIndex(10)).isTrue(); + assertThat(array.length()).isEqualTo(11); // 0-10 inclusive + } + + @Test + @DisplayName("Set higher index than last index expands the array") + void setHigherIndex() { + var array = new DynamicIntArray(); + array.set(5, 24); + array.set(10, 42); + + assertThat(array.get(5)).isEqualTo(24); + assertThat(array.get(10)).isEqualTo(42); + assertThat(array.getFirstIndex()).isEqualTo(5); + assertThat(array.getLastIndex()).isEqualTo(10); + assertThat(array.containsIndex(5)).isTrue(); + assertThat(array.containsIndex(10)).isTrue(); + assertThat(array.length()).isEqualTo(11); // 0-10 inclusive + } + + @Test + @DisplayName("Set existing index updates the value") + void setExistingIndex() { + var array = new DynamicIntArray(); + array.set(5, 24); + array.set(10, 42); + array.set(7, 99); + array.set(7, 100); // Update existing value + + assertThat(array.get(7)).isEqualTo(100); + assertThat(array.getFirstIndex()).isEqualTo(5); + assertThat(array.getLastIndex()).isEqualTo(10); + } + + @ParameterizedTest + @ValueSource(ints = { -1, -5, -100 }) + @DisplayName("Set negative index throws ArrayIndexOutOfBoundsException") + void setNegativeIndex(int index) { + var array = new DynamicIntArray(); + + assertThatExceptionOfType(ArrayIndexOutOfBoundsException.class) + .isThrownBy(() -> array.set(index, 42)); + } + + @Test + @DisplayName("Set index greater than or equal to maxSize throws ArrayIndexOutOfBoundsException") + void setIndexGreaterThanMaxSize() { + var array = new DynamicIntArray(50); + + assertThatExceptionOfType(ArrayIndexOutOfBoundsException.class) + .isThrownBy(() -> array.set(50, 42)); + + assertThatExceptionOfType(ArrayIndexOutOfBoundsException.class) + .isThrownBy(() -> array.set(100, 42)); + } + } + + @Nested + @DisplayName("Get method tests") + class GetMethodTests { + + @Test + @DisplayName("Get returns 0 for empty array") + void getFromEmptyArray() { + var array = new DynamicIntArray(); + + assertThat(array.get(5)).isZero(); + } + + @Test + @DisplayName("Get returns 0 for index lower than first index") + void getIndexLowerThanFirstIndex() { + var array = new DynamicIntArray(); + array.set(10, 42); + + assertThat(array.get(5)).isZero(); + } + + @Test + @DisplayName("Get returns 0 for index higher than last index") + void getIndexHigherThanLastIndex() { + var array = new DynamicIntArray(); + array.set(10, 42); + + assertThat(array.get(15)).isZero(); + } + + @Test + @DisplayName("Get returns correct value for existing index") + void getExistingIndex() { + var array = new DynamicIntArray(); + array.set(5, 24); + array.set(10, 42); + + assertThat(array.get(5)).isEqualTo(24); + assertThat(array.get(10)).isEqualTo(42); + } + } + + @Nested + @DisplayName("ContainsIndex method tests") + class ContainsIndexMethodTests { + + @Test + @DisplayName("ContainsIndex returns false for empty array") + void containsIndexEmptyArray() { + var array = new DynamicIntArray(); + + assertThat(array.containsIndex(0)).isFalse(); + assertThat(array.containsIndex(5)).isFalse(); + } + + @Test + @DisplayName("ContainsIndex returns true for indices within range") + void containsIndexWithinRange() { + var array = new DynamicIntArray(); + array.set(5, 24); + array.set(10, 42); + + assertThat(array.containsIndex(5)).isTrue(); + assertThat(array.containsIndex(7)).isTrue(); + assertThat(array.containsIndex(10)).isTrue(); + } + + @Test + @DisplayName("ContainsIndex returns false for indices outside range") + void containsIndexOutsideRange() { + var array = new DynamicIntArray(); + array.set(5, 24); + array.set(10, 42); + + assertThat(array.containsIndex(4)).isFalse(); + assertThat(array.containsIndex(11)).isFalse(); + } + } + + @Nested + @DisplayName("Length method tests") + class LengthMethodTests { + + @Test + @DisplayName("Length returns 0 for empty array") + void lengthEmptyArray() { + var array = new DynamicIntArray(); + + assertThat(array.length()).isZero(); + } + + @Test + @DisplayName("Length returns correct value after setting elements") + void lengthAfterSettingElements() { + var array = new DynamicIntArray(); + array.set(0, 1); + + assertThat(array.length()).isEqualTo(1); + + array.set(5, 24); + assertThat(array.length()).isEqualTo(6); // 0-5 inclusive + + array.set(10, 42); + assertThat(array.length()).isEqualTo(11); // 0-10 inclusive + } + } + + @Nested + @DisplayName("Complex scenario tests") + class ComplexScenarioTests { + + @Test + @DisplayName("Test multiple operations in sequence") + void testMultipleOperations() { + var array = new DynamicIntArray(); + + // Initial setup + array.set(10, 42); + assertThat(array.get(10)).isEqualTo(42); + assertThat(array.getFirstIndex()).isEqualTo(10); + assertThat(array.getLastIndex()).isEqualTo(10); + + // Expand below + array.set(5, 24); + assertThat(array.get(5)).isEqualTo(24); + assertThat(array.get(10)).isEqualTo(42); + assertThat(array.getFirstIndex()).isEqualTo(5); + assertThat(array.getLastIndex()).isEqualTo(10); + + // Expand above + array.set(15, 99); + assertThat(array.get(5)).isEqualTo(24); + assertThat(array.get(10)).isEqualTo(42); + assertThat(array.get(15)).isEqualTo(99); + assertThat(array.getFirstIndex()).isEqualTo(5); + assertThat(array.getLastIndex()).isEqualTo(15); + + // Update existing + array.set(10, 100); + assertThat(array.get(10)).isEqualTo(100); + + // Verify indices not explicitly set + assertThat(array.get(7)).isZero(); + assertThat(array.get(12)).isZero(); + + // Verify contains index + assertThat(array.containsIndex(5)).isTrue(); + assertThat(array.containsIndex(7)).isTrue(); + assertThat(array.containsIndex(15)).isTrue(); + assertThat(array.containsIndex(4)).isFalse(); + assertThat(array.containsIndex(16)).isFalse(); + + // Verify length + assertThat(array.length()).isEqualTo(16); // 0-15 inclusive + } + + @Test + @DisplayName("Test with sparse indices") + void testWithSparseIndices() { + var array = new DynamicIntArray(); + + array.set(100, 1); + array.set(1000, 2); + array.set(10, 3); + + assertThat(array.getFirstIndex()).isEqualTo(10); + assertThat(array.getLastIndex()).isEqualTo(1000); + assertThat(array.get(10)).isEqualTo(3); + assertThat(array.get(100)).isEqualTo(1); + assertThat(array.get(1000)).isEqualTo(2); + assertThat(array.length()).isEqualTo(1001); // 0-1000 inclusive + } + } + + @Nested + @DisplayName("Clear method tests") + class ClearMethodTests { + + @Test + @DisplayName("Clear on empty array does nothing") + void clearEmptyArray() { + var array = new DynamicIntArray(); + + // Should not throw an exception + array.clear(); + + assertThat(array.length()).isZero(); + } + + @Test + @DisplayName("Clear resets all values to 0 but preserves array structure") + void clearResetsValues() { + var array = new DynamicIntArray(); + array.set(5, 24); + array.set(10, 42); + + array.clear(); + + // Values should be reset to 0 + assertThat(array.get(5)).isZero(); + assertThat(array.get(10)).isZero(); + + // Array structure should be preserved + assertThat(array.getFirstIndex()).isEqualTo(5); + assertThat(array.getLastIndex()).isEqualTo(10); + assertThat(array.containsIndex(5)).isTrue(); + assertThat(array.containsIndex(10)).isTrue(); + assertThat(array.length()).isEqualTo(11); // 0-10 inclusive + } + + @Test + @DisplayName("Clear and then set new values") + void clearAndSetNewValues() { + var array = new DynamicIntArray(); + array.set(5, 24); + array.set(10, 42); + + array.clear(); + + // Set new values + array.set(7, 99); + + // New values should be set correctly + assertThat(array.get(7)).isEqualTo(99); + + // Old indices should still be in the array but with value 0 + assertThat(array.get(5)).isZero(); + assertThat(array.get(10)).isZero(); + + // Array structure should be updated + assertThat(array.getFirstIndex()).isEqualTo(5); + assertThat(array.getLastIndex()).isEqualTo(10); + } + + @Test + @DisplayName("Clear with sparse indices") + void clearWithSparseIndices() { + var array = new DynamicIntArray(); + array.set(10, 1); + array.set(100, 2); + array.set(1000, 3); + + array.clear(); + + // All values should be reset to 0 + assertThat(array.get(10)).isZero(); + assertThat(array.get(100)).isZero(); + assertThat(array.get(1000)).isZero(); + + // Array structure should be preserved + assertThat(array.getFirstIndex()).isEqualTo(10); + assertThat(array.getLastIndex()).isEqualTo(1000); + assertThat(array.length()).isEqualTo(1001); // 0-1000 inclusive + } + } + +} diff --git a/core/src/test/java/ai/timefold/solver/core/impl/util/ListBasedScalingOrderedSetTest.java b/core/src/test/java/ai/timefold/solver/core/impl/util/ListBasedScalingOrderedSetTest.java index 895b04d5a0c..19342d0bef0 100644 --- a/core/src/test/java/ai/timefold/solver/core/impl/util/ListBasedScalingOrderedSetTest.java +++ b/core/src/test/java/ai/timefold/solver/core/impl/util/ListBasedScalingOrderedSetTest.java @@ -1,84 +1,316 @@ package ai.timefold.solver.core.impl.util; -import static ai.timefold.solver.core.impl.util.ListBasedScalingOrderedSet.LIST_SIZE_THRESHOLD; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatThrownBy; import java.util.Arrays; -import java.util.Set; +import java.util.List; import org.junit.jupiter.api.Test; class ListBasedScalingOrderedSetTest { @Test - void addRemoveAroundThreshold() { - Set set = new ListBasedScalingOrderedSet<>(); - assertThat(set.add("s1")).isTrue(); - assertThat(set.add("s1")).isFalse(); - assertThat(set.add("s2")).isTrue(); - assertThat(set.add("s1")).isFalse(); - assertThat(set.add("s2")).isFalse(); - assertThat(set.remove("s2")).isTrue(); - assertThat(set.remove("s2")).isFalse(); - assertThat(set.add("s2")).isTrue(); + void emptySetProperties() { + var set = new ListBasedScalingOrderedSet(); + + assertThat(set) + .doesNotContain("test") + .isEmpty(); + } + + @Test + void addSingleElement() { + var set = new ListBasedScalingOrderedSet(); + + var changed = set.add("test"); + + assertThat(changed).isTrue(); + assertThat(set) + .hasSize(1) + .contains("test"); + } + + @Test + void addDuplicateElement() { + var set = new ListBasedScalingOrderedSet(); + + set.add("test"); + var changed = set.add("test"); + + assertThat(changed).isFalse(); + assertThat(set) + .hasSize(1) + .containsExactly("test"); + } + + @Test + void addAllWithNewElements() { + var set = new ListBasedScalingOrderedSet(); + + var changed = set.addAll(Arrays.asList("a", "b", "c")); + + assertThat(changed).isTrue(); + assertThat(set) + .hasSize(3) + .containsExactly("a", "b", "c"); + } + + @Test + void addAllWithDuplicateElements() { + var set = new ListBasedScalingOrderedSet(); + set.add("a"); + set.add("b"); + + var changed = set.addAll(Arrays.asList("b", "c")); + + assertThat(changed).isTrue(); + assertThat(set) + .hasSize(3) + .containsExactly("a", "b", "c"); + } + + @Test + void addAllWithAllDuplicateElements() { + var set = new ListBasedScalingOrderedSet(); + set.add("a"); + set.add("b"); + + var changed = set.addAll(Arrays.asList("a", "b")); + + assertThat(changed).isFalse(); assertThat(set) .hasSize(2) - .containsExactlyInAnyOrder("s1", "s2"); + .containsExactly("a", "b"); + } - for (int i = 0; i < LIST_SIZE_THRESHOLD - 3; i++) { - set.add("filler " + i); - } - assertThat(set.add("s2")).isFalse(); - assertThat(set.add("s3")).isTrue(); - assertThat(set.add("s2")).isFalse(); - assertThat(set).hasSize(LIST_SIZE_THRESHOLD); - assertThat(set.add("s4")).isTrue(); - assertThat(set.add("s2")).isFalse(); - assertThat(set.add("s3")).isFalse(); - assertThat(set.add("s4")).isFalse(); - assertThat(set).hasSize(LIST_SIZE_THRESHOLD + 1); - assertThat(set.remove("s4")).isTrue(); - assertThat(set.add("s2")).isFalse(); - assertThat(set.add("s3")).isFalse(); - assertThat(set).hasSize(LIST_SIZE_THRESHOLD); - assertThat(set.add("s5")).isTrue(); - assertThat(set.add("s2")).isFalse(); - assertThat(set.add("s3")).isFalse(); - assertThat(set).hasSize(LIST_SIZE_THRESHOLD + 1); - assertThat(set.add("s6")).isTrue(); - assertThat(set.add("s2")).isFalse(); - assertThat(set.add("s3")).isFalse(); + @Test + void removeSingleElement() { + var set = new ListBasedScalingOrderedSet(); + set.add("a"); + + var removed = set.remove("a"); + + assertThat(removed).isTrue(); + assertThat(set).isEmpty(); + } + + @Test + void removeNonexistentElement() { + var set = new ListBasedScalingOrderedSet(); + set.add("a"); + + var removed = set.remove("b"); + + assertThat(removed).isFalse(); assertThat(set) - .hasSize(LIST_SIZE_THRESHOLD + 2) - .contains("s1", "s2", "s3", "s5", "s6") - .doesNotContain("s4"); + .hasSize(1) + .contains("a"); + } + + @Test + void clearEmptiesTheSet() { + var set = new ListBasedScalingOrderedSet(); + set.add("a"); + set.add("b"); + + set.clear(); + + assertThat(set).isEmpty(); + } + + @Test + void toArrayReturnsCorrectArray() { + var set = new ListBasedScalingOrderedSet(); + set.add("a"); + set.add("b"); + + var array = set.toArray(); + + assertThat(array).containsExactly("a", "b"); + } + + @Test + void toArrayWithTypeReturnsCorrectArray() { + var set = new ListBasedScalingOrderedSet(); + set.add("a"); + set.add("b"); + + var array = set.toArray(new String[0]); + + assertThat(array).containsExactly("a", "b"); + } + + @Test + void iteratorReturnsAllElements() { + var set = new ListBasedScalingOrderedSet(); + set.add("a"); + set.add("b"); + + var iterator = set.iterator(); + + assertThat(iterator).hasNext(); + assertThat(iterator.next()).isEqualTo("a"); + assertThat(iterator).hasNext(); + assertThat(iterator.next()).isEqualTo("b"); + assertThat(iterator.hasNext()).isFalse(); + } + + @Test + void iteratorRemoveThrowsException() { + var set = new ListBasedScalingOrderedSet(); + set.add("a"); + + var iterator = set.iterator(); + iterator.next(); + + assertThatExceptionOfType(UnsupportedOperationException.class) + .isThrownBy(iterator::remove); } @Test - void addAllAroundThreshold() { - Set set = new ListBasedScalingOrderedSet<>(); - assertThat(set.addAll(Arrays.asList("s1", "s2", "s3"))).isTrue(); - assertThat(set).hasSize(3); - assertThat(set.addAll(Arrays.asList("s1", "s3", "s4", "s5"))).isTrue(); - assertThat(set.addAll(Arrays.asList("s1", "s2", "s4"))).isFalse(); + void containsWorks() { + var set = new ListBasedScalingOrderedSet(); + set.add("a"); + assertThat(set) - .hasSize(5) - .containsExactlyInAnyOrder("s1", "s2", "s3", "s4", "s5"); + .contains("a") + .doesNotContain("b"); + } + + @Test + void containsAllWorks() { + var set = new ListBasedScalingOrderedSet(); + set.add("a"); + set.add("b"); + + assertThat(set).containsAll(Arrays.asList("a", "b")); + assertThat(set.containsAll(Arrays.asList("a", "c"))).isFalse(); + } + + @Test + void retainAllThrowsException() { + var set = new ListBasedScalingOrderedSet(); + + var list = List.of("a"); + assertThatThrownBy(() -> set.retainAll(list)) + .isInstanceOf(UnsupportedOperationException.class) + .hasMessageContaining("retainAll()"); + } + + @Test + void removeAllThrowsException() { + var set = new ListBasedScalingOrderedSet(); + + var list = List.of("a"); + assertThatThrownBy(() -> set.removeAll(list)) + .isInstanceOf(UnsupportedOperationException.class) + .hasMessageContaining("removeAll()"); + } + + @Test + void toStringWorks() { + var set = new ListBasedScalingOrderedSet(); + set.add("a"); + set.add("b"); + + assertThat(set.toString()).contains("a", "b"); + } + + @Test + void scalingFromListToSet() { + var set = new ListBasedScalingOrderedSet(); + + // Add elements up to the threshold (16) + for (var i = 0; i < ListBasedScalingOrderedSet.LIST_SIZE_THRESHOLD; i++) { + set.add(i); + } + + // At this point, it should still be using a list + assertThat(set).hasSize(ListBasedScalingOrderedSet.LIST_SIZE_THRESHOLD); + + // Adding one more should cause it to switch to a set + set.add(ListBasedScalingOrderedSet.LIST_SIZE_THRESHOLD); + + // Verify it still works correctly + assertThat(set).hasSize(ListBasedScalingOrderedSet.LIST_SIZE_THRESHOLD + 1); + for (var i = 0; i <= ListBasedScalingOrderedSet.LIST_SIZE_THRESHOLD; i++) { + assertThat(set).contains(i); + } + } + + @Test + void scalingFromSetToList() { + var set = new ListBasedScalingOrderedSet(); - for (int i = 0; i < LIST_SIZE_THRESHOLD - 7; i++) { - set.add("filler " + i); + // Add elements beyond threshold to ensure it's using a set + for (var i = 0; i <= ListBasedScalingOrderedSet.LIST_SIZE_THRESHOLD; i++) { + set.add(i); } - assertThat(set).hasSize(LIST_SIZE_THRESHOLD - 2); - assertThat(set.addAll(Arrays.asList("s6", "s7", "s2", "s3", "s8", "s9"))).isTrue(); - assertThat(set).hasSize(LIST_SIZE_THRESHOLD + 2); - assertThat(set.remove("s1")).isTrue(); - assertThat(set.remove("s5")).isTrue(); - assertThat(set).hasSize(LIST_SIZE_THRESHOLD); - assertThat(set.addAll(Arrays.asList("s1", "s2", "s10"))).isTrue(); + + assertThat(set).hasSize(ListBasedScalingOrderedSet.LIST_SIZE_THRESHOLD + 1); + + // Remove elements until we're at threshold + set.remove(ListBasedScalingOrderedSet.LIST_SIZE_THRESHOLD); + + // At threshold, it should still be a set + assertThat(set).hasSize(ListBasedScalingOrderedSet.LIST_SIZE_THRESHOLD); + + // Remove one more to trigger scaling back to list + set.remove(0); + + // Verify it still works correctly assertThat(set) - .hasSize(LIST_SIZE_THRESHOLD + 2) - .contains("s1", "s2", "s3", "s4", "s6", "s7", "s8", "s9", "s10") - .doesNotContain("s5"); + .hasSize(ListBasedScalingOrderedSet.LIST_SIZE_THRESHOLD - 1) + .doesNotContain(0); + for (var i = 1; i < ListBasedScalingOrderedSet.LIST_SIZE_THRESHOLD; i++) { + assertThat(set).contains(i); + } } -} + @Test + void addAllCausingScaling() { + var set = new ListBasedScalingOrderedSet(); + + // Add some elements but stay below threshold + for (var i = 0; i < ListBasedScalingOrderedSet.LIST_SIZE_THRESHOLD - 5; i++) { + set.add(i); + } + + // Prepare a collection that will push it over threshold when added + var toAdd = List.of( + ListBasedScalingOrderedSet.LIST_SIZE_THRESHOLD - 5, + ListBasedScalingOrderedSet.LIST_SIZE_THRESHOLD - 4, + ListBasedScalingOrderedSet.LIST_SIZE_THRESHOLD - 3, + ListBasedScalingOrderedSet.LIST_SIZE_THRESHOLD - 2, + ListBasedScalingOrderedSet.LIST_SIZE_THRESHOLD - 1, + ListBasedScalingOrderedSet.LIST_SIZE_THRESHOLD); + + // Add the collection, which should trigger scaling + var changed = set.addAll(toAdd); + + assertThat(changed).isTrue(); + assertThat(set).hasSize(ListBasedScalingOrderedSet.LIST_SIZE_THRESHOLD + 1); + for (var i = 0; i <= ListBasedScalingOrderedSet.LIST_SIZE_THRESHOLD; i++) { + assertThat(set).contains(i); + } + } + + @Test + void attemptToRemoveNonExistentElementFromSet() { + var set = new ListBasedScalingOrderedSet(); + + // Add enough elements to use a set internally + for (var i = 0; i <= ListBasedScalingOrderedSet.LIST_SIZE_THRESHOLD; i++) { + set.add(i); + } + + // Try to remove an element that doesn't exist + var removed = set.remove(999); + + // Verify element wasn't removed and set didn't change state + assertThat(removed).isFalse(); + assertThat(set).hasSize(ListBasedScalingOrderedSet.LIST_SIZE_THRESHOLD + 1); + } +} \ No newline at end of file