Skip to content

Commit 8502a37

Browse files
committed
rrt & rrt*
1 parent ca24f57 commit 8502a37

12 files changed

Lines changed: 786 additions & 202 deletions

File tree

src/python_motion_planning/common/env/map/grid.py

Lines changed: 36 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
@update: 2025.9.5
66
"""
77
from itertools import product
8-
from typing import Iterable, Union, Tuple, Callable, List
8+
from typing import Iterable, Union, Tuple, Callable, List, Dict
99
import time
1010

1111
import numpy as np
@@ -235,7 +235,7 @@ def shape(self) -> tuple:
235235
def esdf(self) -> np.ndarray:
236236
return self._esdf
237237

238-
def map_to_world(self, point: tuple) -> tuple:
238+
def map_to_world(self, point: Tuple[int, ...]) -> tuple:
239239
"""
240240
Convert map coordinates to world coordinates.
241241
@@ -250,7 +250,7 @@ def map_to_world(self, point: tuple) -> tuple:
250250

251251
return tuple((x + 0.5) * self.resolution + float(self.bounds[i, 0]) for i, x in enumerate(point))
252252

253-
def world_to_map(self, point: tuple) -> tuple:
253+
def world_to_map(self, point: Tuple[float, ...]) -> tuple:
254254
"""
255255
Convert world coordinates to map coordinates.
256256
@@ -265,7 +265,7 @@ def world_to_map(self, point: tuple) -> tuple:
265265

266266
return tuple(round((x - float(self.bounds[i, 0])) * (1.0 / self.resolution) - 0.5) for i, x in enumerate(point))
267267

268-
def get_distance(self, p1: tuple, p2: tuple) -> float:
268+
def get_distance(self, p1: Tuple[int, int], p2: Tuple[int, int]) -> float:
269269
"""
270270
Get the distance between two points.
271271
@@ -278,7 +278,7 @@ def get_distance(self, p1: tuple, p2: tuple) -> float:
278278
"""
279279
return Geometry.dist(p1, p2, type='Euclidean')
280280

281-
def within_bounds(self, point: tuple) -> bool:
281+
def within_bounds(self, point: Tuple[int, ...]) -> bool:
282282
"""
283283
Check if a point is within the bounds of the grid map.
284284
@@ -300,7 +300,7 @@ def within_bounds(self, point: tuple) -> bool:
300300
return False
301301
return True
302302

303-
def is_expandable(self, point: tuple, src_point: tuple = None) -> bool:
303+
def is_expandable(self, point: Tuple[int, ...], src_point: Tuple[int, ...] = None) -> bool:
304304
"""
305305
Check if a point is expandable.
306306
@@ -363,7 +363,7 @@ def get_neighbors(self,
363363

364364
return filtered_neighbors
365365

366-
def line_of_sight(self, p1: tuple, p2: tuple) -> list:
366+
def line_of_sight(self, p1: Tuple[int, ...], p2: Tuple[int, ...]) -> list:
367367
"""
368368
N-dimensional line of sight (Bresenham's line algorithm)
369369
@@ -414,7 +414,7 @@ def line_of_sight(self, p1: tuple, p2: tuple) -> list:
414414

415415
return result
416416

417-
def in_collision(self, p1: tuple, p2: tuple) -> bool:
417+
def in_collision(self, p1: Tuple[int, ...], p2: Tuple[int, ...]) -> bool:
418418
"""
419419
Check if the line of sight between two points is in collision.
420420
@@ -498,17 +498,17 @@ def inflate_obstacles(self, radius: float = 1.0) -> None:
498498
self.type_map[i, j] = TYPES.INFLATION
499499
self.inflation_radius = radius
500500

501-
def fill_expands(self, expands: List[Node]) -> None:
501+
def fill_expands(self, expands: Dict[Tuple[int, int], Node]) -> None:
502502
"""
503503
Fill the expands in the map.
504504
505505
Args:
506506
expands: List of expands.
507507
"""
508-
for expand in expands:
509-
if self.type_map[expand.current] != TYPES.FREE:
508+
for expand in expands.keys():
509+
if self.type_map[expand] != TYPES.FREE:
510510
continue
511-
self.type_map[expand.current] = TYPES.EXPAND
511+
self.type_map[expand] = TYPES.EXPAND
512512

513513
def update_esdf(self) -> None:
514514
"""
@@ -527,6 +527,30 @@ def update_esdf(self) -> None:
527527
self._esdf = dist_outside.astype(np.float32)
528528
self._esdf[obstacle_mask] = -dist_inside[obstacle_mask]
529529

530+
def path_map_to_world(self, path: List[Tuple[int, int]]) -> List[Tuple[float, float]]:
531+
"""
532+
Convert path from map coordinates to world coordinates
533+
534+
Args:
535+
path: a list of map coordinates
536+
537+
Returns:
538+
path: a list of world coordinates
539+
"""
540+
return [self.map_to_world(p) for p in path]
541+
542+
def path_world_to_map(self, path: List[Tuple[float, float]]) -> List[Tuple[int, int]]:
543+
"""
544+
Convert path from world coordinates to map coordinates
545+
546+
Args:
547+
path: a list of world coordinates
548+
549+
Returns:
550+
path: a list of map coordinates
551+
"""
552+
return [self.world_to_map(p) for p in path]
553+
530554
def _precompute_offsets(self):
531555
# Generate all possible offsets (-1, 0, +1) in each dimension
532556
self._diagonal_offsets = np.array(np.meshgrid(*[[-1, 0, 1]]*self.dim), dtype=self.dtype).T.reshape(-1, self.dim)

src/python_motion_planning/common/visualizer/visualizer.py

Lines changed: 82 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
@author: Yang Haodong, Wu Maojia
55
@update: 2025.9.20
66
"""
7-
from typing import Union, Dict
7+
from typing import Union, Dict, List, Tuple
88
from collections import namedtuple
99
import time
1010

@@ -16,7 +16,7 @@
1616
import matplotlib.patheffects as path_effects
1717

1818
from python_motion_planning.controller import BaseController
19-
from python_motion_planning.common.env import TYPES, ToySimulator, Grid, CircularRobot
19+
from python_motion_planning.common.env import TYPES, ToySimulator, Grid, CircularRobot, Node
2020

2121
class Visualizer:
2222
def __init__(self, fig_name: str = ""):
@@ -39,6 +39,10 @@ def __init__(self, fig_name: str = ""):
3939
self.cmap = mcolors.ListedColormap([info for info in self.cmap_dict.values()])
4040
self.norm = mcolors.BoundaryNorm([i for i in range(self.cmap.N + 1)], self.cmap.N)
4141
self.grid_map = None
42+
self.dim = None
43+
44+
def __del__(self):
45+
self.close()
4246

4347
def plot_grid_map(self, grid_map: Grid, equal: bool = False, alpha_3d: float = 0.1,
4448
show_esdf: bool = False, alpha_esdf: float = 0.5) -> None:
@@ -52,6 +56,8 @@ def plot_grid_map(self, grid_map: Grid, equal: bool = False, alpha_3d: float = 0
5256
show_esdf: Whether to show esdf.
5357
alpha_esdf: Alpha of esdf.
5458
'''
59+
self.grid_map = grid_map
60+
self.dim = grid_map.dim
5561
if grid_map.dim == 2:
5662
plt.imshow(
5763
np.transpose(grid_map.type_map.array),
@@ -117,7 +123,9 @@ def plot_grid_map(self, grid_map: Grid, equal: bool = False, alpha_3d: float = 0
117123
else:
118124
raise NotImplementedError(f"Grid map with dim={grid_map.dim} not supported.")
119125

120-
def plot_path(self, path: list, style: str = "-", color: str = "#13ae00", label: str = None, linewidth: float = 2, marker: str = None) -> None:
126+
def plot_path(self, path: List[Union[Tuple[int, ...], Tuple[float, ...]]],
127+
style: str = "-", color: str = "#13ae00", label: str = None,
128+
linewidth: float = 2, marker: str = None, map_frame: bool = True) -> None:
121129
'''
122130
Plot path-like information.
123131
The meaning of parameters are similar to matplotlib.pyplot.plot (https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.plot.html).
@@ -129,20 +137,80 @@ def plot_path(self, path: list, style: str = "-", color: str = "#13ae00", label:
129137
label: label of path
130138
linewidth: linewidth of path
131139
marker: marker of path
140+
map_frame: whether path is in map frame or not (world frame)
132141
'''
142+
if map_frame:
143+
path = [self.grid_map.map_to_world(point) for point in path]
144+
133145
path = np.array(path)
134-
if len(path.shape) < 2:
135-
return
136-
if path.shape[1] == 2:
137-
plt.plot(path[:, 0], path[:, 1], style, lw=linewidth, color=color, label=label, marker=marker)
138-
elif path.shape[1] == 3:
146+
147+
if self.dim == 2:
148+
self.ax.plot(path[:, 0], path[:, 1], style, lw=linewidth, color=color, label=label, marker=marker)
149+
elif self.dim == 3:
139150
self.ax.plot(path[:, 0], path[:, 1], path[:, 2], style, lw=linewidth, color=color, label=label, marker=marker)
140151
else:
141-
raise ValueError("Path dimension not supported")
152+
raise ValueError("Dimension not supported")
142153

143154
if label:
144155
self.ax.legend()
145156

157+
def plot_expand_tree(self, expand_tree: Dict[Union[Tuple[int, ...], Tuple[float, ...]], Node],
158+
node_color: str = "C5",
159+
edge_color: str = "C6",
160+
node_size: float = 10,
161+
linewidth: float = 1.0,
162+
node_alpha: float = 1.0,
163+
edge_alpha: float = 1.0,
164+
connect_to_parent: bool = True,
165+
map_frame: bool = True) -> None:
166+
"""
167+
Visualize an expand tree (e.g. RRT).
168+
169+
Args:
170+
expand_tree: Dict mapping coordinate tuple -> Node (world frame).
171+
node_color: Color of the nodes.
172+
edge_color: Color of the edges (parent -> child).
173+
node_size: Size of node markers.
174+
linewidth: Line width of edges.
175+
connect_to_parent: Whether to draw parent-child connections.
176+
map_frame: whether path is in map frame or not (world frame)
177+
"""
178+
if self.dim == 2:
179+
for coord, node in expand_tree.items():
180+
current = node.current
181+
if map_frame:
182+
current = self.grid_map.map_to_world(current)
183+
184+
self.ax.scatter(current[0], current[1],
185+
c=node_color, s=node_size, zorder=3, alpha=node_alpha)
186+
if connect_to_parent and node.parent is not None:
187+
parent = node.parent
188+
if map_frame:
189+
parent = self.grid_map.map_to_world(parent)
190+
self.ax.plot([parent[0], current[0]],
191+
[parent[1], current[1]],
192+
color=edge_color, linewidth=linewidth, zorder=2, alpha=edge_alpha)
193+
194+
elif self.dim == 3:
195+
for coord, node in expand_tree.items():
196+
current = node.current
197+
if map_frame:
198+
current = self.grid_map.map_to_world(current)
199+
200+
self.ax.scatter(current[0], current[1], current[2],
201+
c=node_color, s=node_size, zorder=3, alpha=node_alpha)
202+
if connect_to_parent and node.parent is not None:
203+
parent = node.parent
204+
if map_frame:
205+
parent = self.grid_map.map_to_world(parent)
206+
self.ax.plot([parent[0], current[0]],
207+
[parent[1], current[1]],
208+
[parent[2], current[2]],
209+
color=edge_color, linewidth=linewidth, zorder=2, alpha=edge_alpha)
210+
211+
else:
212+
raise ValueError("Dimension must be 2 or 3")
213+
146214
def plot_circular_robot(self, robot: CircularRobot, axis_equal: bool = True) -> None:
147215
patch = plt.Circle(tuple(robot.pos), robot.radius,
148216
color=robot.color, alpha=robot.alpha, fill=robot.fill, linewidth=robot.linewidth, linestyle=robot.linestyle)
@@ -165,7 +233,8 @@ def plot_circular_robot(self, robot: CircularRobot, axis_equal: bool = True) ->
165233
else:
166234
return patch, text
167235

168-
def render_toy_simulator(self, env: ToySimulator, controllers: Dict[str, BaseController], steps: int = 1000, interval: int = 50,
236+
def render_toy_simulator(self, env: ToySimulator, controllers: Dict[str, BaseController],
237+
steps: int = 1000, interval: int = 50,
169238
show_traj: bool = True, traj_kwargs: dict = {"linestyle": '-', "alpha": 0.7, "linewidth": 1.5},
170239
show_env_info: bool = False, rtf_limit: float = 1.0, grid_kwargs: dict = {},
171240
show_pred_traj: bool = True) -> None:
@@ -281,6 +350,9 @@ def show(self):
281350

282351
def legend(self):
283352
plt.legend()
353+
354+
def close(self):
355+
plt.close()
284356

285357

286358

0 commit comments

Comments
 (0)