Skip to content

Commit 22a6b47

Browse files
Astar Planner Implementation
1 parent 52caef8 commit 22a6b47

1 file changed

Lines changed: 139 additions & 142 deletions

File tree

Lines changed: 139 additions & 142 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,45 @@
1+
"""
2+
astar_path_planner.py
3+
4+
Author: Shantanu Parab
5+
"""
6+
17
import numpy as np
28
import matplotlib.pyplot as plt
39
import heapq
410
import matplotlib.animation as anm
11+
import numpy as np
12+
import sys
13+
from pathlib import Path
14+
from matplotlib.colors import ListedColormap
15+
16+
abs_dir_path = str(Path(__file__).absolute().parent)
17+
relative_path = "/../../../components/"
18+
relative_simulations = "/../../../simulations/"
19+
20+
21+
sys.path.append(abs_dir_path + relative_path + "visualization")
22+
sys.path.append(abs_dir_path + relative_path + "state")
23+
sys.path.append(abs_dir_path + relative_path + "obstacle")
24+
sys.path.append(abs_dir_path + relative_path + "plan/astar")
25+
sys.path.append(abs_dir_path + relative_path + "mapping/grid")
26+
27+
28+
29+
30+
from state import State
31+
from obstacle import Obstacle
32+
from obstacle_list import ObstacleList
33+
from binary_occupancy_grid import BinaryOccupancyGrid
34+
from min_max import MinMax
35+
import json
36+
37+
38+
539

640

741
class AStarPathPlanner:
8-
def __init__(self, start, goal, obstacle_parameters, resolution=0.1, weight=1.0, obstacle_clearance=0.0, robot_clearance=0.0, visualize=False, x_lim=None, y_lim=None):
42+
def __init__(self, start, goal, map_file, weight=1.0, x_lim=None, y_lim=None, path_filename=None, gif_name=None):
943
"""
1044
Initialize the A* planner.
1145
Args:
@@ -20,135 +54,44 @@ def __init__(self, start, goal, obstacle_parameters, resolution=0.1, weight=1.0,
2054
"""
2155
self.start = start
2256
self.goal = goal
23-
self.obstacle_parameters = obstacle_parameters
24-
self.resolution = resolution
2557
self.weight = weight
26-
self.visualize = visualize
27-
self.x_lim = x_lim
28-
self.y_lim = y_lim
29-
self.obstacle_clearance = obstacle_clearance
30-
self.robot_clearance = robot_clearance
31-
self.clearance = obstacle_clearance + robot_clearance
32-
self.grid, self.x_range, self.y_range = self.create_grid()
33-
self.mark_obstacles()
34-
3558
self.explored_nodes = []
36-
37-
38-
if visualize:
39-
plt.figure(figsize=(10, 8))
40-
plt.imshow(self.grid, extent=[self.x_range[0], self.x_range[-1], self.y_range[0], self.y_range[-1]],
41-
origin='lower', cmap='Greys')
42-
plt.plot(start[0], start[1], 'go', label="Start") # Start point
43-
plt.plot(goal[0], goal[1], 'ro', label="Goal") # Goal point
44-
plt.legend()
45-
plt.show()
46-
47-
48-
def create_grid(self):
49-
"""Create a grid based on the specified or derived limits."""
50-
if self.x_lim and self.y_lim:
51-
x_min, x_max = self.x_lim.min_value(), self.x_lim.max_value()
52-
y_min, y_max = self.y_lim.min_value(), self.y_lim.max_value()
53-
54-
x_range = np.arange(x_min, x_max, self.resolution)
55-
y_range = np.arange(y_min, y_max, self.resolution)
56-
# print("x_range: ", x_range)
57-
# print("y_range: ", y_range)
58-
grid = np.zeros((len(y_range), len(x_range))) # Initialize grid as free space
59-
# print("grid element: ", grid[-10][0])
60-
return grid, x_range, y_range
61-
62-
def mark_obstacles(self):
63-
"""Mark obstacles and their clearance on the grid, considering rotation (yaw)."""
64-
for obs in self.obstacle_parameters:
65-
# Get obstacle parameters
66-
x_c = obs["x_m"]
67-
y_c = obs["y_m"]
68-
yaw = obs["yaw_rad"]
69-
length = obs["length_m"]
70-
width = obs["width_m"]
71-
72-
# Calculate the clearance dimensions
73-
clearance_length = length + self.clearance
74-
clearance_width = width + self.clearance
75-
76-
# Define corners for the clearance area
77-
clearance_corners = np.array([
78-
[-clearance_length, -clearance_width],
79-
[-clearance_length, clearance_width],
80-
[clearance_length, clearance_width],
81-
[clearance_length, -clearance_width]
82-
])
83-
84-
# Define corners for the actual obstacle
85-
obstacle_corners = np.array([
86-
[-length, -width],
87-
[-length, width],
88-
[length, width],
89-
[length, -width]
90-
])
91-
92-
# Apply rotation to both obstacle and clearance corners
93-
rotation_matrix = np.array([
94-
[np.cos(yaw), -np.sin(yaw)],
95-
[np.sin(yaw), np.cos(yaw)]
96-
])
97-
rotated_clearance_corners = np.dot(clearance_corners, rotation_matrix.T) + np.array([x_c, y_c])
98-
rotated_obstacle_corners = np.dot(obstacle_corners, rotation_matrix.T) + np.array([x_c, y_c])
99-
100-
# Mark the clearance area
101-
self._mark_area(rotated_clearance_corners, value=0.5) # 0.5 for clearance
102-
103-
# Mark the actual obstacle area
104-
self._mark_area(rotated_obstacle_corners, value=1.0) # 1.0 for obstacles
105-
106-
def _point_in_polygon(self, x, y, corners):
59+
self.grid = self.load_grid_from_file(map_file)
60+
x_min, x_max = x_lim.min_value(), x_lim.max_value()
61+
y_min, y_max = y_lim.min_value(), y_lim.max_value()
62+
self.resolution = (x_max - x_min) / self.grid.shape[1] # Width of each cell
63+
self.x_range = np.arange(x_min, x_max, self.resolution)
64+
self.y_range = np.arange(y_min, y_max, self.resolution)
65+
self.path = []
66+
self.path_filename = path_filename
67+
self.search()
68+
self.visualize_search(gif_name)
69+
70+
def load_grid_from_file(self, file_path):
10771
"""
108-
Check if a point (x, y) is inside a polygon defined by corners.
72+
Load a grid from a file and convert it to a numpy array.
10973
Args:
110-
x: X-coordinate of the point.
111-
y: Y-coordinate of the point.
112-
corners: Array of polygon corners in global coordinates.
74+
file_path: Path to the file containing the grid data.
11375
Returns:
114-
True if the point is inside the polygon, False otherwise.
115-
"""
116-
n = len(corners)
117-
inside = False
118-
px, py = x, y
119-
for i in range(n):
120-
x1, y1 = corners[i]
121-
x2, y2 = corners[(i + 1) % n]
122-
if ((y1 > py) != (y2 > py)) and \
123-
(px < (x2 - x1) * (py - y1) / (y2 - y1 + 1e-6) + x1):
124-
inside = not inside
125-
return inside
126-
127-
128-
def _mark_area(self, corners, value):
129-
"""
130-
Mark a rectangular area on the grid based on the given rotated corners.
131-
Args:
132-
corners: The rotated corners of the area in global coordinates.
133-
value: The value to mark in the grid (e.g., 0.5 for clearance, 1.0 for obstacles).
76+
grid: A numpy array representing the grid.
13477
"""
135-
# Get the bounding box of the corners
136-
x_min = max(0, int((min(corners[:, 0]) - self.x_range[0]) / self.resolution))
137-
x_max = min(self.grid.shape[1], int((max(corners[:, 0]) - self.x_range[0]) / self.resolution))
138-
y_min = max(0, int((min(corners[:, 1]) - self.y_range[0]) / self.resolution))
139-
y_max = min(self.grid.shape[0], int((max(corners[:, 1]) - self.y_range[0]) / self.resolution))
140-
141-
# Iterate through the grid cells in the bounding box
142-
for x in range(x_min, x_max):
143-
for y in range(y_min, y_max):
144-
# Get the center of the current cell
145-
cell_x = self.x_range[0] + x * self.resolution + self.resolution / 2
146-
cell_y = self.y_range[0] + y * self.resolution + self.resolution / 2
147-
148-
# Check if the cell center is inside the rotated polygon
149-
if self._point_in_polygon(cell_x, cell_y, corners):
150-
self.grid[y, x] = max(self.grid[y, x], value) # Mark the cell
78+
file_extension = Path(file_path).suffix
79+
80+
if file_extension == '.npy':
81+
grid = np.load(file_path)
82+
elif file_extension == '.png':
83+
grid = plt.imread(file_path)
84+
if grid.ndim == 3: # If the image has color channels, convert to grayscale
85+
grid = np.mean(grid, axis=2)
86+
grid = (grid > 0.5).astype(int) # Binarize the image
87+
elif file_extension == '.json':
88+
with open(file_path, 'r') as f:
89+
grid_data = json.load(f)
90+
grid = np.array(grid_data)
91+
else:
92+
raise ValueError(f"Unsupported file format: {file_extension}")
15193

94+
return grid
15295

15396
def heuristic(self, a, b):
15497
return self.weight * (abs(a[0] - b[0]) + abs(a[1] - b[1]))
@@ -159,15 +102,15 @@ def is_valid(self, x, y):
159102
Converts world coordinates to grid indices, accounting for negative min values.
160103
"""
161104
# Check if indices are within bounds and not an obstacle
162-
return (0 <= x < self.grid.shape[1] and
163-
0 <= y < self.grid.shape[0] and
105+
return (0 <= x < self.grid.shape[1] and
106+
0 <= y < self.grid.shape[0] and
164107
self.grid[y, x] == 0)
165108

166109
def search(self):
167-
start_idx = (int((self.start[0] - self.x_range[0]) / self.resolution),
168-
int((self.start[1] - self.y_range[0]) / self.resolution))
169-
goal_idx = (int((self.goal[0] - self.x_range[0]) / self.resolution),
170-
int((self.goal[1] - self.y_range[0]) / self.resolution))
110+
start_idx = (int((self.start[0] - self.x_range[0]) /self.resolution),
111+
int((self.start[1] - self.y_range[0]) /self.resolution))
112+
goal_idx = (int((self.goal[0] - self.x_range[0]) /self.resolution),
113+
int((self.goal[1] - self.y_range[0]) /self.resolution))
171114

172115
open_list = []
173116
heapq.heappush(open_list, (0, start_idx))
@@ -180,7 +123,10 @@ def search(self):
180123
self.explored_nodes.append(current)
181124
if current == goal_idx:
182125
print(f"Goal found at: {current}")
183-
return self.reconstruct_path(came_from, start_idx, goal_idx)
126+
self.path = self.reconstruct_path(came_from, start_idx, goal_idx)
127+
sparse_path = self.make_sparse_path(self.path)
128+
self.save_path(sparse_path, self.path_filename)
129+
return
184130

185131
for dx, dy in [(-1, 0), (1, 0), (0, -1), (0, 1),(1, 1), (-1, -1), (1, -1), (-1, 1)]:
186132
neighbor = (current[0] + dx, current[1] + dy)
@@ -192,7 +138,6 @@ def search(self):
192138
priority = new_cost + self.heuristic(neighbor, goal_idx)
193139
heapq.heappush(open_list, (priority, neighbor))
194140
came_from[neighbor] = current
195-
196141

197142
return []
198143

@@ -224,8 +169,8 @@ def _grid_to_world(self, grid_node):
224169
(world_x, world_y): Corresponding world coordinates.
225170
"""
226171
grid_x, grid_y = grid_node
227-
world_x = self.x_range[0] + grid_x * self.resolution
228-
world_y = self.y_range[0] + grid_y * self.resolution
172+
world_x = self.x_range[0] + grid_x *self.resolution
173+
world_y = self.y_range[0] + grid_y *self.resolution
229174
return (world_x, world_y)
230175

231176
def make_sparse_path(self, path, num_points=20):
@@ -245,8 +190,18 @@ def make_sparse_path(self, path, num_points=20):
245190
indices = np.linspace(0, len(path) - 1, num_points, dtype=int)
246191
sparse_path = [self._grid_to_world(path[i]) for i in indices]
247192
return sparse_path
248-
249-
def visualize_search(self, path, gif_name=None):
193+
194+
def save_path(self, path, filename):
195+
196+
"""Save path to a json file."""
197+
if not Path(filename).exists():
198+
Path(filename).touch()
199+
path = [node for node in path]
200+
with open(filename, "w") as f:
201+
json.dump(path, f)
202+
203+
204+
def visualize_search(self, gif_name=None):
250205
print(f"Exploring {len(self.explored_nodes)} nodes.")
251206
if not self.explored_nodes:
252207
print("Error: No explored nodes. Ensure search() is executed before visualize_search().")
@@ -259,21 +214,29 @@ def visualize_search(self, path, gif_name=None):
259214
axes.set_xlabel("X [m]", fontsize=15)
260215
axes.set_ylabel("Y [m]", fontsize=15)
261216

217+
262218
self.anime = anm.FuncAnimation(
263219
figure,
264220
self.update_frame,
265-
fargs=(axes, path),
266-
frames=len(self.explored_nodes) + len(path), # Include frames for the path
221+
fargs=(axes, self.path),
222+
frames=len(self.explored_nodes) + len(self.path), # Include frames for the path
267223
interval=50,
268224
repeat=False,
269225
)
270226

271-
if gif_name:
227+
if gif_name is not None:
272228
try:
229+
print("Saving animation...")
273230
self.anime.save(gif_name, writer="pillow")
231+
print("Animation saved successfully.")
274232
except Exception as e:
275233
print(f"Error saving animation: {e}")
276-
plt.show()
234+
else:
235+
plt.show()
236+
237+
# clear existing plot and close existing figure
238+
plt.clf()
239+
plt.close()
277240

278241

279242
def update_frame(self, i, axes, path):
@@ -290,7 +253,7 @@ def update_frame(self, i, axes, path):
290253
node = self.explored_nodes[i]
291254
grid_x = int(node[0])
292255
grid_y = int(node[1])
293-
self.grid[grid_y, grid_x] = 0.5 # Set a value to represent explored nodes
256+
self.grid[grid_y, grid_x] = 0.25 # Set a value to represent explored nodes
294257

295258
# Path reconstruction phase
296259
else:
@@ -299,14 +262,48 @@ def update_frame(self, i, axes, path):
299262
node = path[path_index]
300263
grid_x = int(node[0])
301264
grid_y = int(node[1])
302-
self.grid[grid_y, grid_x] = 0.75 # Set a value to represent the path
265+
self.grid[grid_y, grid_x] = 0.5 # Set a value to represent the path
303266

304267
# Clear the axes and redraw the updated grid
305268
axes.clear()
269+
270+
# Define RGB colors for each grid value
271+
# Colors in the format [R, G, B], where values are in the range [0, 1]
272+
colors = [
273+
[1.0, 1.0, 1.0], # Free space (white)
274+
[0.4, 0.8, 1.0], # Explored nodes (light blue)
275+
[0.0, 1.0, 0.0], # Path (green)
276+
[0.5, 0.5, 0.5], # Clearance space (yellow-orange)
277+
[0.0, 0.0, 0.0], # Obstacles (red)
278+
]
279+
280+
# Create a colormap
281+
custom_cmap = ListedColormap(colors)
282+
283+
306284
axes.imshow(self.grid, extent=[self.x_range[0], self.x_range[-1], self.y_range[0], self.y_range[-1]],
307-
origin='lower', cmap='coolwarm', alpha=0.8)
285+
origin='lower', cmap=custom_cmap, alpha=0.8)
308286
axes.plot(self.start[0], self.start[1], 'go', label="Start")
309287
axes.plot(self.goal[0], self.goal[1], 'ro', label="Goal")
310288
axes.legend()
311289

312290

291+
if __name__ == "__main__":
292+
293+
# The path to the map file where the planner will search for a path
294+
map_file = "map.json"
295+
# Define the path file to save the path that is generated by the planner
296+
path_file = "path.json"
297+
# Visualize the search process and save the gif
298+
gif_path = "astar_search.gif"
299+
300+
x_lim, y_lim = MinMax(-5, 55), MinMax(-20, 25)
301+
302+
# Define the start and goal positions
303+
start = (0, 0)
304+
goal = (50, -10)
305+
306+
# Create the A* planner
307+
planner = AStarPathPlanner(start, goal, map_file, weight=5.0, x_lim=x_lim, y_lim=y_lim, path_filename=path_file, gif_name=gif_path)
308+
309+

0 commit comments

Comments
 (0)