1+ """
2+ astar_path_planner.py
3+
4+ Author: Shantanu Parab
5+ """
6+
17import numpy as np
28import matplotlib .pyplot as plt
39import heapq
410import matplotlib .animation as anm
11+ import numpy as np
12+ import sys
13+ from pathlib import Path
14+ from matplotlib .colors import ListedColormap
15+
16+ abs_dir_path = str (Path (__file__ ).absolute ().parent )
17+ relative_path = "/../../../components/"
18+ relative_simulations = "/../../../simulations/"
19+
20+
21+ sys .path .append (abs_dir_path + relative_path + "visualization" )
22+ sys .path .append (abs_dir_path + relative_path + "state" )
23+ sys .path .append (abs_dir_path + relative_path + "obstacle" )
24+ sys .path .append (abs_dir_path + relative_path + "plan/astar" )
25+ sys .path .append (abs_dir_path + relative_path + "mapping/grid" )
26+
27+
28+
29+
30+ from state import State
31+ from obstacle import Obstacle
32+ from obstacle_list import ObstacleList
33+ from binary_occupancy_grid import BinaryOccupancyGrid
34+ from min_max import MinMax
35+ import json
36+
37+
38+
539
640
741class AStarPathPlanner :
8- def __init__ (self , start , goal , obstacle_parameters , resolution = 0.1 , weight = 1.0 , obstacle_clearance = 0.0 , robot_clearance = 0.0 , visualize = False , x_lim = None , y_lim = None ):
42+ def __init__ (self , start , goal , map_file , weight = 1.0 , x_lim = None , y_lim = None , path_filename = None , gif_name = None ):
943 """
1044 Initialize the A* planner.
1145 Args:
@@ -20,135 +54,44 @@ def __init__(self, start, goal, obstacle_parameters, resolution=0.1, weight=1.0,
2054 """
2155 self .start = start
2256 self .goal = goal
23- self .obstacle_parameters = obstacle_parameters
24- self .resolution = resolution
2557 self .weight = weight
26- self .visualize = visualize
27- self .x_lim = x_lim
28- self .y_lim = y_lim
29- self .obstacle_clearance = obstacle_clearance
30- self .robot_clearance = robot_clearance
31- self .clearance = obstacle_clearance + robot_clearance
32- self .grid , self .x_range , self .y_range = self .create_grid ()
33- self .mark_obstacles ()
34-
3558 self .explored_nodes = []
36-
37-
38- if visualize :
39- plt .figure (figsize = (10 , 8 ))
40- plt .imshow (self .grid , extent = [self .x_range [0 ], self .x_range [- 1 ], self .y_range [0 ], self .y_range [- 1 ]],
41- origin = 'lower' , cmap = 'Greys' )
42- plt .plot (start [0 ], start [1 ], 'go' , label = "Start" ) # Start point
43- plt .plot (goal [0 ], goal [1 ], 'ro' , label = "Goal" ) # Goal point
44- plt .legend ()
45- plt .show ()
46-
47-
48- def create_grid (self ):
49- """Create a grid based on the specified or derived limits."""
50- if self .x_lim and self .y_lim :
51- x_min , x_max = self .x_lim .min_value (), self .x_lim .max_value ()
52- y_min , y_max = self .y_lim .min_value (), self .y_lim .max_value ()
53-
54- x_range = np .arange (x_min , x_max , self .resolution )
55- y_range = np .arange (y_min , y_max , self .resolution )
56- # print("x_range: ", x_range)
57- # print("y_range: ", y_range)
58- grid = np .zeros ((len (y_range ), len (x_range ))) # Initialize grid as free space
59- # print("grid element: ", grid[-10][0])
60- return grid , x_range , y_range
61-
62- def mark_obstacles (self ):
63- """Mark obstacles and their clearance on the grid, considering rotation (yaw)."""
64- for obs in self .obstacle_parameters :
65- # Get obstacle parameters
66- x_c = obs ["x_m" ]
67- y_c = obs ["y_m" ]
68- yaw = obs ["yaw_rad" ]
69- length = obs ["length_m" ]
70- width = obs ["width_m" ]
71-
72- # Calculate the clearance dimensions
73- clearance_length = length + self .clearance
74- clearance_width = width + self .clearance
75-
76- # Define corners for the clearance area
77- clearance_corners = np .array ([
78- [- clearance_length , - clearance_width ],
79- [- clearance_length , clearance_width ],
80- [clearance_length , clearance_width ],
81- [clearance_length , - clearance_width ]
82- ])
83-
84- # Define corners for the actual obstacle
85- obstacle_corners = np .array ([
86- [- length , - width ],
87- [- length , width ],
88- [length , width ],
89- [length , - width ]
90- ])
91-
92- # Apply rotation to both obstacle and clearance corners
93- rotation_matrix = np .array ([
94- [np .cos (yaw ), - np .sin (yaw )],
95- [np .sin (yaw ), np .cos (yaw )]
96- ])
97- rotated_clearance_corners = np .dot (clearance_corners , rotation_matrix .T ) + np .array ([x_c , y_c ])
98- rotated_obstacle_corners = np .dot (obstacle_corners , rotation_matrix .T ) + np .array ([x_c , y_c ])
99-
100- # Mark the clearance area
101- self ._mark_area (rotated_clearance_corners , value = 0.5 ) # 0.5 for clearance
102-
103- # Mark the actual obstacle area
104- self ._mark_area (rotated_obstacle_corners , value = 1.0 ) # 1.0 for obstacles
105-
106- def _point_in_polygon (self , x , y , corners ):
59+ self .grid = self .load_grid_from_file (map_file )
60+ x_min , x_max = x_lim .min_value (), x_lim .max_value ()
61+ y_min , y_max = y_lim .min_value (), y_lim .max_value ()
62+ self .resolution = (x_max - x_min ) / self .grid .shape [1 ] # Width of each cell
63+ self .x_range = np .arange (x_min , x_max , self .resolution )
64+ self .y_range = np .arange (y_min , y_max , self .resolution )
65+ self .path = []
66+ self .path_filename = path_filename
67+ self .search ()
68+ self .visualize_search (gif_name )
69+
70+ def load_grid_from_file (self , file_path ):
10771 """
108- Check if a point (x, y) is inside a polygon defined by corners .
72+ Load a grid from a file and convert it to a numpy array .
10973 Args:
110- x: X-coordinate of the point.
111- y: Y-coordinate of the point.
112- corners: Array of polygon corners in global coordinates.
74+ file_path: Path to the file containing the grid data.
11375 Returns:
114- True if the point is inside the polygon, False otherwise.
115- """
116- n = len (corners )
117- inside = False
118- px , py = x , y
119- for i in range (n ):
120- x1 , y1 = corners [i ]
121- x2 , y2 = corners [(i + 1 ) % n ]
122- if ((y1 > py ) != (y2 > py )) and \
123- (px < (x2 - x1 ) * (py - y1 ) / (y2 - y1 + 1e-6 ) + x1 ):
124- inside = not inside
125- return inside
126-
127-
128- def _mark_area (self , corners , value ):
129- """
130- Mark a rectangular area on the grid based on the given rotated corners.
131- Args:
132- corners: The rotated corners of the area in global coordinates.
133- value: The value to mark in the grid (e.g., 0.5 for clearance, 1.0 for obstacles).
76+ grid: A numpy array representing the grid.
13477 """
135- # Get the bounding box of the corners
136- x_min = max (0 , int ((min (corners [:, 0 ]) - self .x_range [0 ]) / self .resolution ))
137- x_max = min (self .grid .shape [1 ], int ((max (corners [:, 0 ]) - self .x_range [0 ]) / self .resolution ))
138- y_min = max (0 , int ((min (corners [:, 1 ]) - self .y_range [0 ]) / self .resolution ))
139- y_max = min (self .grid .shape [0 ], int ((max (corners [:, 1 ]) - self .y_range [0 ]) / self .resolution ))
140-
141- # Iterate through the grid cells in the bounding box
142- for x in range (x_min , x_max ):
143- for y in range (y_min , y_max ):
144- # Get the center of the current cell
145- cell_x = self .x_range [0 ] + x * self .resolution + self .resolution / 2
146- cell_y = self .y_range [0 ] + y * self .resolution + self .resolution / 2
147-
148- # Check if the cell center is inside the rotated polygon
149- if self ._point_in_polygon (cell_x , cell_y , corners ):
150- self .grid [y , x ] = max (self .grid [y , x ], value ) # Mark the cell
78+ file_extension = Path (file_path ).suffix
79+
80+ if file_extension == '.npy' :
81+ grid = np .load (file_path )
82+ elif file_extension == '.png' :
83+ grid = plt .imread (file_path )
84+ if grid .ndim == 3 : # If the image has color channels, convert to grayscale
85+ grid = np .mean (grid , axis = 2 )
86+ grid = (grid > 0.5 ).astype (int ) # Binarize the image
87+ elif file_extension == '.json' :
88+ with open (file_path , 'r' ) as f :
89+ grid_data = json .load (f )
90+ grid = np .array (grid_data )
91+ else :
92+ raise ValueError (f"Unsupported file format: { file_extension } " )
15193
94+ return grid
15295
15396 def heuristic (self , a , b ):
15497 return self .weight * (abs (a [0 ] - b [0 ]) + abs (a [1 ] - b [1 ]))
@@ -159,15 +102,15 @@ def is_valid(self, x, y):
159102 Converts world coordinates to grid indices, accounting for negative min values.
160103 """
161104 # Check if indices are within bounds and not an obstacle
162- return (0 <= x < self .grid .shape [1 ] and
163- 0 <= y < self .grid .shape [0 ] and
105+ return (0 <= x < self .grid .shape [1 ] and
106+ 0 <= y < self .grid .shape [0 ] and
164107 self .grid [y , x ] == 0 )
165108
166109 def search (self ):
167- start_idx = (int ((self .start [0 ] - self .x_range [0 ]) / self .resolution ),
168- int ((self .start [1 ] - self .y_range [0 ]) / self .resolution ))
169- goal_idx = (int ((self .goal [0 ] - self .x_range [0 ]) / self .resolution ),
170- int ((self .goal [1 ] - self .y_range [0 ]) / self .resolution ))
110+ start_idx = (int ((self .start [0 ] - self .x_range [0 ]) / self .resolution ),
111+ int ((self .start [1 ] - self .y_range [0 ]) / self .resolution ))
112+ goal_idx = (int ((self .goal [0 ] - self .x_range [0 ]) / self .resolution ),
113+ int ((self .goal [1 ] - self .y_range [0 ]) / self .resolution ))
171114
172115 open_list = []
173116 heapq .heappush (open_list , (0 , start_idx ))
@@ -180,7 +123,10 @@ def search(self):
180123 self .explored_nodes .append (current )
181124 if current == goal_idx :
182125 print (f"Goal found at: { current } " )
183- return self .reconstruct_path (came_from , start_idx , goal_idx )
126+ self .path = self .reconstruct_path (came_from , start_idx , goal_idx )
127+ sparse_path = self .make_sparse_path (self .path )
128+ self .save_path (sparse_path , self .path_filename )
129+ return
184130
185131 for dx , dy in [(- 1 , 0 ), (1 , 0 ), (0 , - 1 ), (0 , 1 ),(1 , 1 ), (- 1 , - 1 ), (1 , - 1 ), (- 1 , 1 )]:
186132 neighbor = (current [0 ] + dx , current [1 ] + dy )
@@ -192,7 +138,6 @@ def search(self):
192138 priority = new_cost + self .heuristic (neighbor , goal_idx )
193139 heapq .heappush (open_list , (priority , neighbor ))
194140 came_from [neighbor ] = current
195-
196141
197142 return []
198143
@@ -224,8 +169,8 @@ def _grid_to_world(self, grid_node):
224169 (world_x, world_y): Corresponding world coordinates.
225170 """
226171 grid_x , grid_y = grid_node
227- world_x = self .x_range [0 ] + grid_x * self .resolution
228- world_y = self .y_range [0 ] + grid_y * self .resolution
172+ world_x = self .x_range [0 ] + grid_x * self .resolution
173+ world_y = self .y_range [0 ] + grid_y * self .resolution
229174 return (world_x , world_y )
230175
231176 def make_sparse_path (self , path , num_points = 20 ):
@@ -245,8 +190,18 @@ def make_sparse_path(self, path, num_points=20):
245190 indices = np .linspace (0 , len (path ) - 1 , num_points , dtype = int )
246191 sparse_path = [self ._grid_to_world (path [i ]) for i in indices ]
247192 return sparse_path
248-
249- def visualize_search (self , path , gif_name = None ):
193+
194+ def save_path (self , path , filename ):
195+
196+ """Save path to a json file."""
197+ if not Path (filename ).exists ():
198+ Path (filename ).touch ()
199+ path = [node for node in path ]
200+ with open (filename , "w" ) as f :
201+ json .dump (path , f )
202+
203+
204+ def visualize_search (self , gif_name = None ):
250205 print (f"Exploring { len (self .explored_nodes )} nodes." )
251206 if not self .explored_nodes :
252207 print ("Error: No explored nodes. Ensure search() is executed before visualize_search()." )
@@ -259,21 +214,29 @@ def visualize_search(self, path, gif_name=None):
259214 axes .set_xlabel ("X [m]" , fontsize = 15 )
260215 axes .set_ylabel ("Y [m]" , fontsize = 15 )
261216
217+
262218 self .anime = anm .FuncAnimation (
263219 figure ,
264220 self .update_frame ,
265- fargs = (axes , path ),
266- frames = len (self .explored_nodes ) + len (path ), # Include frames for the path
221+ fargs = (axes , self . path ),
222+ frames = len (self .explored_nodes ) + len (self . path ), # Include frames for the path
267223 interval = 50 ,
268224 repeat = False ,
269225 )
270226
271- if gif_name :
227+ if gif_name is not None :
272228 try :
229+ print ("Saving animation..." )
273230 self .anime .save (gif_name , writer = "pillow" )
231+ print ("Animation saved successfully." )
274232 except Exception as e :
275233 print (f"Error saving animation: { e } " )
276- plt .show ()
234+ else :
235+ plt .show ()
236+
237+ # clear existing plot and close existing figure
238+ plt .clf ()
239+ plt .close ()
277240
278241
279242 def update_frame (self , i , axes , path ):
@@ -290,7 +253,7 @@ def update_frame(self, i, axes, path):
290253 node = self .explored_nodes [i ]
291254 grid_x = int (node [0 ])
292255 grid_y = int (node [1 ])
293- self .grid [grid_y , grid_x ] = 0.5 # Set a value to represent explored nodes
256+ self .grid [grid_y , grid_x ] = 0.25 # Set a value to represent explored nodes
294257
295258 # Path reconstruction phase
296259 else :
@@ -299,14 +262,48 @@ def update_frame(self, i, axes, path):
299262 node = path [path_index ]
300263 grid_x = int (node [0 ])
301264 grid_y = int (node [1 ])
302- self .grid [grid_y , grid_x ] = 0.75 # Set a value to represent the path
265+ self .grid [grid_y , grid_x ] = 0.5 # Set a value to represent the path
303266
304267 # Clear the axes and redraw the updated grid
305268 axes .clear ()
269+
270+ # Define RGB colors for each grid value
271+ # Colors in the format [R, G, B], where values are in the range [0, 1]
272+ colors = [
273+ [1.0 , 1.0 , 1.0 ], # Free space (white)
274+ [0.4 , 0.8 , 1.0 ], # Explored nodes (light blue)
275+ [0.0 , 1.0 , 0.0 ], # Path (green)
276+ [0.5 , 0.5 , 0.5 ], # Clearance space (yellow-orange)
277+ [0.0 , 0.0 , 0.0 ], # Obstacles (red)
278+ ]
279+
280+ # Create a colormap
281+ custom_cmap = ListedColormap (colors )
282+
283+
306284 axes .imshow (self .grid , extent = [self .x_range [0 ], self .x_range [- 1 ], self .y_range [0 ], self .y_range [- 1 ]],
307- origin = 'lower' , cmap = 'coolwarm' , alpha = 0.8 )
285+ origin = 'lower' , cmap = custom_cmap , alpha = 0.8 )
308286 axes .plot (self .start [0 ], self .start [1 ], 'go' , label = "Start" )
309287 axes .plot (self .goal [0 ], self .goal [1 ], 'ro' , label = "Goal" )
310288 axes .legend ()
311289
312290
291+ if __name__ == "__main__" :
292+
293+ # The path to the map file where the planner will search for a path
294+ map_file = "map.json"
295+ # Define the path file to save the path that is generated by the planner
296+ path_file = "path.json"
297+ # Visualize the search process and save the gif
298+ gif_path = "astar_search.gif"
299+
300+ x_lim , y_lim = MinMax (- 5 , 55 ), MinMax (- 20 , 25 )
301+
302+ # Define the start and goal positions
303+ start = (0 , 0 )
304+ goal = (50 , - 10 )
305+
306+ # Create the A* planner
307+ planner = AStarPathPlanner (start , goal , map_file , weight = 5.0 , x_lim = x_lim , y_lim = y_lim , path_filename = path_file , gif_name = gif_path )
308+
309+
0 commit comments