44@author: Yang Haodong, Wu Maojia
55@update: 2025.9.20
66"""
7- from typing import Union , Dict
7+ from typing import Union , Dict , List , Tuple
88from collections import namedtuple
99import time
1010
1616import matplotlib .patheffects as path_effects
1717
1818from python_motion_planning .controller import BaseController
19- from python_motion_planning .common .env import TYPES , ToySimulator , Grid , CircularRobot
19+ from python_motion_planning .common .env import TYPES , ToySimulator , Grid , CircularRobot , Node
2020
2121class Visualizer :
2222 def __init__ (self , fig_name : str = "" ):
@@ -39,6 +39,10 @@ def __init__(self, fig_name: str = ""):
3939 self .cmap = mcolors .ListedColormap ([info for info in self .cmap_dict .values ()])
4040 self .norm = mcolors .BoundaryNorm ([i for i in range (self .cmap .N + 1 )], self .cmap .N )
4141 self .grid_map = None
42+ self .dim = None
43+
44+ def __del__ (self ):
45+ self .close ()
4246
4347 def plot_grid_map (self , grid_map : Grid , equal : bool = False , alpha_3d : float = 0.1 ,
4448 show_esdf : bool = False , alpha_esdf : float = 0.5 ) -> None :
@@ -52,6 +56,8 @@ def plot_grid_map(self, grid_map: Grid, equal: bool = False, alpha_3d: float = 0
5256 show_esdf: Whether to show esdf.
5357 alpha_esdf: Alpha of esdf.
5458 '''
59+ self .grid_map = grid_map
60+ self .dim = grid_map .dim
5561 if grid_map .dim == 2 :
5662 plt .imshow (
5763 np .transpose (grid_map .type_map .array ),
@@ -117,7 +123,9 @@ def plot_grid_map(self, grid_map: Grid, equal: bool = False, alpha_3d: float = 0
117123 else :
118124 raise NotImplementedError (f"Grid map with dim={ grid_map .dim } not supported." )
119125
120- def plot_path (self , path : list , style : str = "-" , color : str = "#13ae00" , label : str = None , linewidth : float = 2 , marker : str = None ) -> None :
126+ def plot_path (self , path : List [Union [Tuple [int , ...], Tuple [float , ...]]],
127+ style : str = "-" , color : str = "#13ae00" , label : str = None ,
128+ linewidth : float = 2 , marker : str = None , map_frame : bool = True ) -> None :
121129 '''
122130 Plot path-like information.
123131 The meaning of parameters are similar to matplotlib.pyplot.plot (https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.plot.html).
@@ -129,20 +137,80 @@ def plot_path(self, path: list, style: str = "-", color: str = "#13ae00", label:
129137 label: label of path
130138 linewidth: linewidth of path
131139 marker: marker of path
140+ map_frame: whether path is in map frame or not (world frame)
132141 '''
142+ if map_frame :
143+ path = [self .grid_map .map_to_world (point ) for point in path ]
144+
133145 path = np .array (path )
134- if len (path .shape ) < 2 :
135- return
136- if path .shape [1 ] == 2 :
137- plt .plot (path [:, 0 ], path [:, 1 ], style , lw = linewidth , color = color , label = label , marker = marker )
138- elif path .shape [1 ] == 3 :
146+
147+ if self .dim == 2 :
148+ self .ax .plot (path [:, 0 ], path [:, 1 ], style , lw = linewidth , color = color , label = label , marker = marker )
149+ elif self .dim == 3 :
139150 self .ax .plot (path [:, 0 ], path [:, 1 ], path [:, 2 ], style , lw = linewidth , color = color , label = label , marker = marker )
140151 else :
141- raise ValueError ("Path dimension not supported" )
152+ raise ValueError ("Dimension not supported" )
142153
143154 if label :
144155 self .ax .legend ()
145156
157+ def plot_expand_tree (self , expand_tree : Dict [Union [Tuple [int , ...], Tuple [float , ...]], Node ],
158+ node_color : str = "C5" ,
159+ edge_color : str = "C6" ,
160+ node_size : float = 10 ,
161+ linewidth : float = 1.0 ,
162+ node_alpha : float = 1.0 ,
163+ edge_alpha : float = 1.0 ,
164+ connect_to_parent : bool = True ,
165+ map_frame : bool = True ) -> None :
166+ """
167+ Visualize an expand tree (e.g. RRT).
168+
169+ Args:
170+ expand_tree: Dict mapping coordinate tuple -> Node (world frame).
171+ node_color: Color of the nodes.
172+ edge_color: Color of the edges (parent -> child).
173+ node_size: Size of node markers.
174+ linewidth: Line width of edges.
175+ connect_to_parent: Whether to draw parent-child connections.
176+ map_frame: whether path is in map frame or not (world frame)
177+ """
178+ if self .dim == 2 :
179+ for coord , node in expand_tree .items ():
180+ current = node .current
181+ if map_frame :
182+ current = self .grid_map .map_to_world (current )
183+
184+ self .ax .scatter (current [0 ], current [1 ],
185+ c = node_color , s = node_size , zorder = 3 , alpha = node_alpha )
186+ if connect_to_parent and node .parent is not None :
187+ parent = node .parent
188+ if map_frame :
189+ parent = self .grid_map .map_to_world (parent )
190+ self .ax .plot ([parent [0 ], current [0 ]],
191+ [parent [1 ], current [1 ]],
192+ color = edge_color , linewidth = linewidth , zorder = 2 , alpha = edge_alpha )
193+
194+ elif self .dim == 3 :
195+ for coord , node in expand_tree .items ():
196+ current = node .current
197+ if map_frame :
198+ current = self .grid_map .map_to_world (current )
199+
200+ self .ax .scatter (current [0 ], current [1 ], current [2 ],
201+ c = node_color , s = node_size , zorder = 3 , alpha = node_alpha )
202+ if connect_to_parent and node .parent is not None :
203+ parent = node .parent
204+ if map_frame :
205+ parent = self .grid_map .map_to_world (parent )
206+ self .ax .plot ([parent [0 ], current [0 ]],
207+ [parent [1 ], current [1 ]],
208+ [parent [2 ], current [2 ]],
209+ color = edge_color , linewidth = linewidth , zorder = 2 , alpha = edge_alpha )
210+
211+ else :
212+ raise ValueError ("Dimension must be 2 or 3" )
213+
146214 def plot_circular_robot (self , robot : CircularRobot , axis_equal : bool = True ) -> None :
147215 patch = plt .Circle (tuple (robot .pos ), robot .radius ,
148216 color = robot .color , alpha = robot .alpha , fill = robot .fill , linewidth = robot .linewidth , linestyle = robot .linestyle )
@@ -165,7 +233,8 @@ def plot_circular_robot(self, robot: CircularRobot, axis_equal: bool = True) ->
165233 else :
166234 return patch , text
167235
168- def render_toy_simulator (self , env : ToySimulator , controllers : Dict [str , BaseController ], steps : int = 1000 , interval : int = 50 ,
236+ def render_toy_simulator (self , env : ToySimulator , controllers : Dict [str , BaseController ],
237+ steps : int = 1000 , interval : int = 50 ,
169238 show_traj : bool = True , traj_kwargs : dict = {"linestyle" : '-' , "alpha" : 0.7 , "linewidth" : 1.5 },
170239 show_env_info : bool = False , rtf_limit : float = 1.0 , grid_kwargs : dict = {},
171240 show_pred_traj : bool = True ) -> None :
@@ -281,6 +350,9 @@ def show(self):
281350
282351 def legend (self ):
283352 plt .legend ()
353+
354+ def close (self ):
355+ plt .close ()
284356
285357
286358
0 commit comments