Skip to content

Commit 2c84447

Browse files
committed
cleaner return type from .plan()
1 parent d822aa7 commit 2c84447

4 files changed

Lines changed: 43 additions & 39 deletions

File tree

PathPlanning/TimeBasedPathPlanning/BaseClasses.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,8 @@ class MultiAgentPlanner(ABC):
4343

4444
@staticmethod
4545
@abstractmethod
46-
def plan(grid: Grid, start_and_goal_positions: list[StartAndGoal], single_agent_planner_class: SingleAgentPlanner, verbose: bool = False) -> tuple[list[StartAndGoal], list[NodePath]]:
46+
def plan(grid: Grid, start_and_goal_positions: list[StartAndGoal], single_agent_planner_class: SingleAgentPlanner, verbose: bool = False) -> dict[AgentId, NodePath]:
4747
"""
48-
Plan for all agents. Returned paths are in order corresponding to the returned list of `StartAndGoal` objects
48+
Plan for all agents. Returned paths found for each agent
4949
"""
5050
pass

PathPlanning/TimeBasedPathPlanning/ConflictBasedSearch.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,9 @@ class ConflictBasedSearch(MultiAgentPlanner):
3434

3535

3636
@staticmethod
37-
def plan(grid: Grid, start_and_goals: list[StartAndGoal], single_agent_planner_class: SingleAgentPlanner, verbose: bool = False) -> tuple[list[StartAndGoal], list[NodePath]]:
37+
def plan(grid: Grid, start_and_goals: list[StartAndGoal], single_agent_planner_class: SingleAgentPlanner, verbose: bool = False) -> dict[AgentId, NodePath]:
3838
"""
3939
Generate a path from the start to the goal for each agent in the `start_and_goals` list.
40-
Returns the re-ordered StartAndGoal combinations, and a list of path plans. The order of the plans
41-
corresponds to the order of the `start_and_goals` list.
4240
"""
4341
print(f"Using single-agent planner: {single_agent_planner_class}")
4442

@@ -72,7 +70,10 @@ def plan(grid: Grid, start_and_goals: list[StartAndGoal], single_agent_planner_c
7270
print(f"\nFound a path with constraints after {constraint_tree.expanded_node_count()} expansions")
7371
print(f"Final cost: {constraint_tree_node.cost}")
7472
print(f"Number of constraints on solution: {len(constraint_tree_node.all_constraints)}")
75-
return (start_and_goals, [constraint_tree_node.paths[start_and_goal.agent_id] for start_and_goal in start_and_goals])
73+
final_paths = {}
74+
for start_and_goal in start_and_goals:
75+
final_paths[start_and_goal.agent_id] = constraint_tree_node.paths[start_and_goal.agent_id]
76+
return final_paths
7677

7778
if not isinstance(constraint_tree_node.constraint, ForkingConstraint):
7879
raise ValueError(f"Expected a ForkingConstraint, but got: {constraint_tree_node.constraint}")
@@ -89,7 +90,6 @@ def plan(grid: Grid, start_and_goals: list[StartAndGoal], single_agent_planner_c
8990

9091
# Deepcopy to update with applied constraint and new paths
9192
applied_constraint_parent = deepcopy(constraint_tree_node)
92-
# Copy paths for child node - we just need to update constrained agent's path
9393
applied_constraint_parent.paths[constrained_agent.agent] = new_path
9494

9595
if verbose:
@@ -163,6 +163,7 @@ class Scenario(Enum):
163163
scenario = Scenario.HALLWAY_CROSS
164164
verbose = False
165165
show_animation = True
166+
use_sipp = False # Condition here mainly to appease the linter
166167
np.random.seed(42) # For reproducibility
167168
def main():
168169
grid_side_length = 21
@@ -187,16 +188,16 @@ def main():
187188
obstacle_arrangement=obstacle_arrangement,
188189
)
189190

191+
single_agent_planner = SafeIntervalPathPlanner if use_sipp else SpaceTimeAStar
190192
start_time = time.time()
191-
start_and_goals, paths = ConflictBasedSearch.plan(grid, start_and_goals, SafeIntervalPathPlanner, verbose)
192-
# start_and_goals, paths = ConflictBasedSearch.plan(grid, start_and_goals, SpaceTimeAStar, verbose)
193+
paths = ConflictBasedSearch.plan(grid, start_and_goals, single_agent_planner, verbose)
193194

194195
runtime = time.time() - start_time
195196
print(f"\nPlanning took: {runtime:.5f} seconds")
196197

197198
if verbose:
198199
print(f"Paths:")
199-
for path in paths:
200+
for path in paths.values():
200201
print(f"{path}\n")
201202

202203
if not show_animation:

PathPlanning/TimeBasedPathPlanning/Plotting.py

Lines changed: 26 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
)
88
from PathPlanning.TimeBasedPathPlanning.BaseClasses import StartAndGoal
99
from PathPlanning.TimeBasedPathPlanning.Node import NodePath
10+
from PathPlanning.TimeBasedPathPlanning.ConstraintTree import AgentId
1011

1112
'''
1213
Plot a single agent path.
@@ -50,7 +51,7 @@ def PlotNodePath(grid: Grid, start: Position, goal: Position, path: NodePath):
5051
'''
5152
Plot a series of agent paths.
5253
'''
53-
def PlotNodePaths(grid: Grid, start_and_goals: list[StartAndGoal], paths: list[NodePath]):
54+
def PlotNodePaths(grid: Grid, start_and_goals: list[StartAndGoal], paths: dict[AgentId, NodePath]):
5455
fig = plt.figure(figsize=(10, 7))
5556

5657
ax = fig.add_subplot(
@@ -64,19 +65,21 @@ def PlotNodePaths(grid: Grid, start_and_goals: list[StartAndGoal], paths: list[N
6465
ax.set_yticks(np.arange(0, grid.grid_size[1], 1))
6566

6667
# Plot start and goal positions for each agent
67-
colors = [] # generated randomly in loop
68+
colors = {} # generated randomly in loop. Maps agent id to color
6869
markers = ['D', 's', '^', 'o', 'p'] # Different markers for visual distinction
6970

7071
# Create plots for start and goal positions
7172
start_and_goal_plots = []
72-
for i, path in enumerate(paths):
73-
marker_idx = i % len(markers)
74-
agent_id = start_and_goals[i].agent_id
75-
start = start_and_goals[i].start
76-
goal = start_and_goals[i].goal
73+
for agent_id, path in paths.items():
74+
marker_idx = agent_id % len(markers)
75+
start_and_goal = next((elem for elem in start_and_goals if elem.agent_id == agent_id), None)
76+
if not start_and_goal:
77+
raise RuntimeError(f"Failed to get start and goal for agent {agent_id}")
78+
start = start_and_goal.start
79+
goal = start_and_goal.goal
7780

7881
color = np.random.rand(3,)
79-
colors.append(color)
82+
colors[agent_id] = color
8083
sg_plot, = ax.plot([], [], markers[marker_idx], c=color, ms=15,
8184
label=f"Agent {agent_id} Start/Goal")
8285
sg_plot.set_data([start.x, goal.x], [start.y, goal.y])
@@ -86,12 +89,11 @@ def PlotNodePaths(grid: Grid, start_and_goals: list[StartAndGoal], paths: list[N
8689
(obs_points,) = ax.plot([], [], "ro", ms=15, label="Obstacles")
8790

8891
# Create plots for each agent's path
89-
path_plots = []
90-
for i, path in enumerate(paths):
91-
agent_id = start_and_goals[i].agent_id
92-
path_plot, = ax.plot([], [], "o", c=colors[i], ms=10,
92+
path_plots = {}
93+
for agent_id, path in paths.items():
94+
path_plot, = ax.plot([], [], "o", c=colors[agent_id], ms=10,
9395
label=f"Agent {agent_id} Path")
94-
path_plots.append(path_plot)
96+
path_plots[agent_id] = path_plot
9597

9698
ax.legend(bbox_to_anchor=(1.05, 1))
9799

@@ -103,32 +105,32 @@ def PlotNodePaths(grid: Grid, start_and_goals: list[StartAndGoal], paths: list[N
103105
)
104106

105107
# Find the maximum time across all paths
106-
max_time = max(path.goal_reached_time() for path in paths)
108+
max_time = max(path.goal_reached_time() for path in paths.values())
107109

108110
# Animation loop
109-
for i in range(0, max_time + 1):
111+
for t in range(0, max_time + 1):
110112
# Update obstacle positions
111-
obs_positions = grid.get_obstacle_positions_at_time(i)
113+
obs_positions = grid.get_obstacle_positions_at_time(t)
112114
obs_points.set_data(obs_positions[0], obs_positions[1])
113115

114116
# Update each agent's position
115-
for (j, path) in enumerate(paths):
117+
for agent_id, path in paths.items():
116118
path_positions = []
117-
if i <= path.goal_reached_time():
118-
res = path.get_position(i)
119+
if t <= path.goal_reached_time():
120+
res = path.get_position(t)
119121
if not res:
120-
print(path)
121-
print(i)
122-
path_position = path.get_position(i)
122+
print(f"Error getting position for agent {agent_id} at time {t}")
123+
print(t)
124+
path_position = path.get_position(t)
123125
if not path_position:
124-
raise Exception(f"Path position not found for time {i}.")
126+
raise Exception(f"Path position not found for time {t}.")
125127

126128
# Verify position is valid
127129
assert not path_position in obs_positions
128130
assert not path_position in path_positions
129131
path_positions.append(path_position)
130132

131-
path_plots[j].set_data([path_position.x], [path_position.y])
133+
path_plots[agent_id].set_data([path_position.x], [path_position.y])
132134

133135
plt.pause(0.2)
134136

PathPlanning/TimeBasedPathPlanning/PriorityBasedPlanner.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,13 @@
1818
from PathPlanning.TimeBasedPathPlanning.BaseClasses import SingleAgentPlanner
1919
from PathPlanning.TimeBasedPathPlanning.SafeInterval import SafeIntervalPathPlanner
2020
from PathPlanning.TimeBasedPathPlanning.Plotting import PlotNodePaths
21+
from PathPlanning.TimeBasedPathPlanning.ConstraintTree import AgentId
2122
import time
2223

2324
class PriorityBasedPlanner(MultiAgentPlanner):
2425

2526
@staticmethod
26-
def plan(grid: Grid, start_and_goals: list[StartAndGoal], single_agent_planner_class: SingleAgentPlanner, verbose: bool = False) -> tuple[list[StartAndGoal], list[NodePath]]:
27+
def plan(grid: Grid, start_and_goals: list[StartAndGoal], single_agent_planner_class: SingleAgentPlanner, verbose: bool = False) -> dict[AgentId, NodePath]:
2728
"""
2829
Generate a path from the start to the goal for each agent in the `start_and_goals` list.
2930
Returns the re-ordered StartAndGoal combinations, and a list of path plans. The order of the plans
@@ -40,7 +41,7 @@ def plan(grid: Grid, start_and_goals: list[StartAndGoal], single_agent_planner_c
4041
key=lambda item: item.distance_start_to_goal(),
4142
reverse=True)
4243

43-
paths = []
44+
paths = {}
4445
for start_and_goal in start_and_goals:
4546
if verbose:
4647
print(f"\nPlanning for agent: {start_and_goal}" )
@@ -54,9 +55,9 @@ def plan(grid: Grid, start_and_goals: list[StartAndGoal], single_agent_planner_c
5455

5556
agent_index = start_and_goal.agent_id
5657
grid.reserve_path(path, agent_index)
57-
paths.append(path)
58+
paths[start_and_goal.agent_id] =path
5859

59-
return (start_and_goals, paths)
60+
return paths
6061

6162
verbose = False
6263
show_animation = True
@@ -76,7 +77,7 @@ def main():
7677
)
7778

7879
start_time = time.time()
79-
start_and_goals, paths = PriorityBasedPlanner.plan(grid, start_and_goals, SafeIntervalPathPlanner, verbose)
80+
paths = PriorityBasedPlanner.plan(grid, start_and_goals, SafeIntervalPathPlanner, verbose)
8081

8182
runtime = time.time() - start_time
8283
print(f"\nPlanning took: {runtime:.5f} seconds")

0 commit comments

Comments
 (0)