Skip to content

Commit e839b3a

Browse files
committed
fix tests + break up monster func
1 parent 2c84447 commit e839b3a

4 files changed

Lines changed: 95 additions & 82 deletions

File tree

PathPlanning/TimeBasedPathPlanning/ConflictBasedSearch.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""
2-
Conflict Based Search generates paths in 2 dimensions (x, y, time) for a set of agents. It does
2+
Conflict Based Search generates paths in 3 dimensions (x, y, time) for a set of agents. It does
33
so by performing searches on two levels. The top level search applies constraints that agents
44
must avoid, and the bottom level performs a single-agent search for individual agents given
55
the constraints provided by the top level search. Initially, paths are generated for each agent
@@ -105,14 +105,18 @@ def plan(grid: Grid, start_and_goals: list[StartAndGoal], single_agent_planner_c
105105
constraint_tree.add_node_to_tree(new_constraint_tree_node)
106106

107107
raise RuntimeError("No solution found")
108-
108+
109+
@staticmethod
109110
def get_agents_start_and_goal(start_and_goal_list: list[StartAndGoal], target_index: AgentId) -> StartAndGoal:
111+
"""
112+
Returns the start and goal of a specific agent
113+
"""
110114
for item in start_and_goal_list:
111115
if item.agent_id == target_index:
112116
return item
113117
raise RuntimeError(f"Could not find agent with index {target_index} in {start_and_goal_list}")
114118

115-
119+
@staticmethod
116120
def plan_for_agent(constrained_agent: ConstraintTreeNode,
117121
all_constraints: list[AppliedConstraint],
118122
constraint_tree: ConstraintTree,
@@ -121,7 +125,9 @@ def plan_for_agent(constrained_agent: ConstraintTreeNode,
121125
single_agent_planner_class: SingleAgentPlanner,
122126
start_and_goals: list[StartAndGoal],
123127
verbose: False) -> Optional[tuple[list[StartAndGoal], list[NodePath]]]:
124-
128+
"""
129+
Attempt to generate a path plan for a single agent
130+
"""
125131
num_expansions = constraint_tree.expanded_node_count()
126132
if num_expansions % 50 == 0:
127133
print(f"Expanded {num_expansions} nodes so far...")

PathPlanning/TimeBasedPathPlanning/ConstraintTree.py

Lines changed: 63 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -46,68 +46,76 @@ def __lt__(self, other) -> bool:
4646
return self.cost < other.cost
4747

4848
def get_constraint_point(self, verbose = False) -> Optional[ForkingConstraint]:
49+
"""
50+
Check paths for any constraints, and if any are found return the earliest one.
51+
"""
4952

5053
final_t = max(path.goal_reached_time() for path in self.paths.values())
5154
positions_at_time: dict[PositionAtTime, AgentId] = {}
5255
for t in range(final_t + 1):
53-
possible_constraints: list[ForkingConstraint] = []
54-
for agent_id, path in self.paths.items():
55-
# Check for edge conflicts
56-
last_position = None
57-
if t > 0:
58-
last_position = path.get_position(t - 1)
59-
60-
position = path.get_position(t)
61-
if position is None:
62-
continue
63-
position_at_time = PositionAtTime(position, t)
64-
if position_at_time not in positions_at_time:
65-
positions_at_time[position_at_time] = AgentId(agent_id)
66-
67-
# edge conflict
68-
if last_position:
69-
new_position_at_last_time = PositionAtTime(position, t-1)
70-
old_position_at_new_time = PositionAtTime(last_position, t)
71-
if new_position_at_last_time in positions_at_time and old_position_at_new_time in positions_at_time:
72-
conflicting_agent_id1 = positions_at_time[new_position_at_last_time]
73-
conflicting_agent_id2 = positions_at_time[old_position_at_new_time]
74-
75-
if conflicting_agent_id1 == conflicting_agent_id2 and conflicting_agent_id1 != agent_id:
76-
if verbose:
77-
print(f"Found edge constraint between with agent {conflicting_agent_id1} for {agent_id}")
78-
print(f"\tpositions old: {old_position_at_new_time}, new: {position_at_time}")
79-
new_constraint = ForkingConstraint((
80-
ConstrainedAgent(agent_id, position_at_time),
81-
ConstrainedAgent(conflicting_agent_id1, Constraint(position=last_position, time=t))
82-
))
83-
possible_constraints.append(new_constraint)
84-
continue
85-
86-
# double reservation at a (cell, time) combination
87-
if positions_at_time[position_at_time] != agent_id:
88-
conflicting_agent_id = positions_at_time[position_at_time]
89-
90-
constraint = Constraint(position=position, time=t)
91-
possible_constraints.append(ForkingConstraint((
92-
ConstrainedAgent(agent_id, constraint), ConstrainedAgent(conflicting_agent_id, constraint)
93-
)))
94-
continue
95-
if possible_constraints:
96-
if verbose:
97-
print(f"Choosing best constraint of {possible_constraints}")
98-
# first check for edge constraints
99-
for constraint in possible_constraints:
100-
if constraint.constrained_agents[0].constraint.position != constraint.constrained_agents[1].constraint.position:
101-
if verbose:
102-
print(f"\tFound edge conflict constraint: {constraint}")
103-
return constraint
104-
# if none, then return first normal constraint
105-
if verbose:
106-
print(f"\tReturning normal constraint: {possible_constraints[0]}")
107-
return possible_constraints[0]
56+
possible_constraints: list[ForkingConstraint] = self.check_for_constraints_at_time(positions_at_time, t, verbose)
57+
if not possible_constraints:
58+
continue
59+
60+
if verbose:
61+
print(f"Choosing best constraint of {possible_constraints}")
62+
# first check for edge constraints
63+
for constraint in possible_constraints:
64+
if constraint.constrained_agents[0].constraint.position != constraint.constrained_agents[1].constraint.position:
65+
if verbose:
66+
print(f"\tFound edge conflict constraint: {constraint}")
67+
return constraint
68+
# if none, then return first normal constraint
69+
if verbose:
70+
print(f"\tReturning normal constraint: {possible_constraints[0]}")
71+
return possible_constraints[0]
10872

10973
return None
11074

75+
def check_for_constraints_at_time(self, positions_at_time: dict[PositionAtTime, AgentId], t: int, verbose: bool) -> list[ForkingConstraint]:
76+
"""
77+
Check for constraints between paths at a particular time step
78+
"""
79+
possible_constraints: list[ForkingConstraint] = []
80+
for agent_id, path in self.paths.items():
81+
82+
position = path.get_position(t)
83+
if position is None:
84+
continue
85+
position_at_time = PositionAtTime(position, t)
86+
if position_at_time not in positions_at_time:
87+
positions_at_time[position_at_time] = AgentId(agent_id)
88+
89+
# double reservation at a (cell, time) combination
90+
if positions_at_time[position_at_time] != agent_id:
91+
conflicting_agent_id = positions_at_time[position_at_time]
92+
constraint = Constraint(position=position, time=t)
93+
possible_constraints.append(ForkingConstraint((
94+
ConstrainedAgent(agent_id, constraint), ConstrainedAgent(conflicting_agent_id, constraint)
95+
)))
96+
97+
# Check for edge conflicts (can only happen after first time step)
98+
if t == 0:
99+
continue
100+
last_position = path.get_position(t - 1)
101+
new_position_at_last_time = PositionAtTime(position, t-1)
102+
old_position_at_new_time = PositionAtTime(last_position, t)
103+
if new_position_at_last_time in positions_at_time and old_position_at_new_time in positions_at_time:
104+
conflicting_agent_id1 = positions_at_time[new_position_at_last_time]
105+
conflicting_agent_id2 = positions_at_time[old_position_at_new_time]
106+
107+
if conflicting_agent_id1 == conflicting_agent_id2 and conflicting_agent_id1 != agent_id:
108+
if verbose:
109+
print(f"Found edge constraint between with agent {conflicting_agent_id1} for {agent_id}")
110+
print(f"\tpositions old: {old_position_at_new_time}, new: {position_at_time}")
111+
new_constraint = ForkingConstraint((
112+
ConstrainedAgent(agent_id, position_at_time),
113+
ConstrainedAgent(conflicting_agent_id1, Constraint(position=last_position, time=t))
114+
))
115+
possible_constraints.append(new_constraint)
116+
117+
return possible_constraints
118+
111119

112120
class ConstraintTree:
113121
# Child nodes have been created (Maps node_index to ConstraintTreeNode)

tests/test_conflict_based_search.py

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,13 @@ def test_no_constraints(single_agent_planner):
2929
obstacle_arrangement=ObstacleArrangement.NONE,
3030
)
3131

32-
start_and_goals: list[StartAndGoal]
3332
paths: list[NodePath]
34-
start_and_goals, paths = m.ConflictBasedSearch.plan(grid, start_and_goals, single_agent_planner, False)
33+
paths = m.ConflictBasedSearch.plan(grid, start_and_goals, single_agent_planner, False)
3534

3635
# All paths should start at the specified position and reach the goal
37-
for i, start_and_goal in enumerate(start_and_goals):
38-
assert paths[i].path[0].position == start_and_goal.start
39-
assert paths[i].path[-1].position == start_and_goal.goal
36+
for start_and_goal in start_and_goals:
37+
assert paths[start_and_goal.agent_id].path[0].position == start_and_goal.start
38+
assert paths[start_and_goal.agent_id].path[-1].position == start_and_goal.goal
4039

4140
@pytest.mark.parametrize("single_agent_planner", [SpaceTimeAStar, SafeIntervalPathPlanner])
4241
def test_narrow_corridor(single_agent_planner):
@@ -53,14 +52,13 @@ def test_narrow_corridor(single_agent_planner):
5352
obstacle_arrangement=ObstacleArrangement.NARROW_CORRIDOR,
5453
)
5554

56-
start_and_goals: list[StartAndGoal]
5755
paths: list[NodePath]
58-
start_and_goals, paths = m.ConflictBasedSearch.plan(grid, start_and_goals, single_agent_planner, False)
56+
paths = m.ConflictBasedSearch.plan(grid, start_and_goals, single_agent_planner, False)
5957

6058
# All paths should start at the specified position and reach the goal
61-
for i, start_and_goal in enumerate(start_and_goals):
62-
assert paths[i].path[0].position == start_and_goal.start
63-
assert paths[i].path[-1].position == start_and_goal.goal
59+
for start_and_goal in start_and_goals:
60+
assert paths[start_and_goal.agent_id].path[0].position == start_and_goal.start
61+
assert paths[start_and_goal.agent_id].path[-1].position == start_and_goal.goal
6462

6563
@pytest.mark.parametrize("single_agent_planner", [SpaceTimeAStar, SafeIntervalPathPlanner])
6664
def test_hallway_pass(single_agent_planner: SingleAgentPlanner):
@@ -78,14 +76,13 @@ def test_hallway_pass(single_agent_planner: SingleAgentPlanner):
7876
obstacle_arrangement=ObstacleArrangement.HALLWAY,
7977
)
8078

81-
start_and_goals: list[StartAndGoal]
8279
paths: list[NodePath]
83-
start_and_goals, paths = m.ConflictBasedSearch.plan(grid, start_and_goals, single_agent_planner, False)
80+
paths = m.ConflictBasedSearch.plan(grid, start_and_goals, single_agent_planner, False)
8481

8582
# All paths should start at the specified position and reach the goal
86-
for i, start_and_goal in enumerate(start_and_goals):
87-
assert paths[i].path[0].position == start_and_goal.start
88-
assert paths[i].path[-1].position == start_and_goal.goal
83+
for start_and_goal in start_and_goals:
84+
assert paths[start_and_goal.agent_id].path[0].position == start_and_goal.start
85+
assert paths[start_and_goal.agent_id].path[-1].position == start_and_goal.goal
8986

9087
if __name__ == "__main__":
9188
m.show_animation = False

tests/test_priority_based_planner.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,18 @@
66
)
77
from PathPlanning.TimeBasedPathPlanning.BaseClasses import StartAndGoal
88
from PathPlanning.TimeBasedPathPlanning import PriorityBasedPlanner as m
9+
from PathPlanning.TimeBasedPathPlanning.SpaceTimeAStar import SpaceTimeAStar
910
from PathPlanning.TimeBasedPathPlanning.SafeInterval import SafeIntervalPathPlanner
11+
from PathPlanning.TimeBasedPathPlanning.BaseClasses import SingleAgentPlanner
1012
import numpy as np
1113
import conftest
14+
import pytest
1215

13-
14-
def test_1():
16+
@pytest.mark.parametrize("single_agent_planner", [SpaceTimeAStar, SafeIntervalPathPlanner])
17+
def test_1(single_agent_planner: SingleAgentPlanner):
1518
grid_side_length = 21
1619

17-
start_and_goals = [StartAndGoal(i, Position(1, i), Position(19, 19-i)) for i in range(1, 16)]
20+
start_and_goals = [StartAndGoal(i, Position(1, i), Position(19, 19-i)) for i in range(1, 11)]
1821
obstacle_avoid_points = [pos for item in start_and_goals for pos in (item.start, item.goal)]
1922

2023
grid = Grid(
@@ -26,14 +29,13 @@ def test_1():
2629

2730
m.show_animation = False
2831

29-
start_and_goals: list[StartAndGoal]
3032
paths: list[NodePath]
31-
start_and_goals, paths = m.PriorityBasedPlanner.plan(grid, start_and_goals, SafeIntervalPathPlanner, False)
33+
paths = m.PriorityBasedPlanner.plan(grid, start_and_goals, single_agent_planner, False)
3234

3335
# All paths should start at the specified position and reach the goal
34-
for i, start_and_goal in enumerate(start_and_goals):
35-
assert paths[i].path[0].position == start_and_goal.start
36-
assert paths[i].path[-1].position == start_and_goal.goal
36+
for start_and_goal in start_and_goals:
37+
assert paths[start_and_goal.agent_id].path[0].position == start_and_goal.start
38+
assert paths[start_and_goal.agent_id].path[-1].position == start_and_goal.goal
3739

3840
if __name__ == "__main__":
3941
conftest.run_this_test(__file__)

0 commit comments

Comments
 (0)