1010from pytorch_mppi import mppi
1111import pytorch_seed
1212import logging
13+ import collections
14+
15+ def is_sequence (obj ):
16+ if isinstance (obj , str ):
17+ return False
18+ return isinstance (obj , collections .abc .Sequence )
1319
14- # import window_recorder
1520
1621plt .switch_backend ('Qt5Agg' )
1722
@@ -156,19 +161,33 @@ def start_visualization(self):
156161 # self.draw_start()
157162 self .draw_goal ()
158163
159- def draw_rollouts (self , rollouts , color = "skyblue" , label = None ):
164+ @staticmethod
165+ def get_v (i , values , rollouts ):
166+ if type (values ) != str and is_sequence (values ) and len (values ) == len (rollouts ):
167+ c = values [i ]
168+ else :
169+ c = values
170+ return c
171+
172+ def draw_rollouts (self , rollouts , color = "skyblue" , label = None , end_state_color = "tab:red" , linewidth = 1.5 ):
160173 if not self .visualize :
161174 return
162175 self .clear_artist (self .rollout_artist )
163176 artists = []
164- for rollout in rollouts :
165- # r = torch.cat((self.start.reshape(1, -1), rollout))
177+ for i , rollout in enumerate (rollouts ):
178+ # prepend start state
179+ rollout = torch .cat ([self .state .view (1 , - 1 ), rollout ], dim = 0 )
166180 r = rollout .cpu ()
167181 artists += [self .ax .scatter (r [0 , 0 ], r [0 , 1 ], color = "tab:blue" )]
168- artists += self .ax .plot (r [:, 0 ], r [:, 1 ], color = color , label = label )
169- artists += [self .ax .scatter (r [- 1 , 0 ], r [- 1 , 1 ], color = "tab:red" )]
182+ # if color is a string treat it as a single color, otherwise treat it as a list of colors
183+ artists += self .ax .plot (r [:, 0 ], r [:, 1 ],
184+ color = self .get_v (i , color , rollouts ),
185+ label = self .get_v (i , label , rollouts ),
186+ linewidth = self .get_v (i , linewidth , rollouts ))
187+ artists += [self .ax .scatter (r [- 1 , 0 ], r [- 1 , 1 ], color = self .get_v (i , end_state_color , rollouts ))]
170188 self .rollout_artist = artists
171- self .ax .legend ()
189+ if label is not None :
190+ self .ax .legend (loc = "upper right" )
172191 plt .pause (0.001 )
173192
174193 def draw_trajectory_step (self , prev_state , cur_state , color = "tab:blue" ):
@@ -179,6 +198,7 @@ def draw_trajectory_step(self, prev_state, cur_state, color="tab:blue"):
179198 artists = self .trajectory_artist
180199 artists += self .ax .plot ([prev_state [0 ].cpu (), cur_state [0 ].cpu ()],
181200 [prev_state [1 ].cpu (), cur_state [1 ].cpu ()], color = color )
201+ plt .draw ()
182202 plt .pause (0.001 )
183203
184204 def clear_trajectory (self ):
@@ -265,8 +285,8 @@ def make_gif_ffmpeg(imgs_dir, gif_name, fps=6):
265285 subprocess .run (cmd )
266286
267287
268- def do_control (env , ctrl , ch , seeds = (0 ,), run_steps = 20 , num_refinement_steps = 1 , save_img = True , plot_single = False ):
269- evaluate_running_cost = True
288+ def do_control (env , ctrl , ch , seeds = (0 ,), run_steps = 20 , num_refinement_steps = 1 , save_img = True , plot_single = False ,
289+ evaluate_running_cost = True , plot_trajectory_candidates = False ):
270290 if save_img :
271291 os .makedirs ("images" , exist_ok = True )
272292 os .makedirs ("images/runs" , exist_ok = True )
@@ -307,10 +327,51 @@ def do_control(env, ctrl, ch, seeds=(0,), run_steps=20, num_refinement_steps=1,
307327 for t in range (len (rollout ) - 1 ):
308328 rollout_cost = rollout_cost + env .running_cost (rollout [t ], ctrl .U [t ])
309329 rollout_cost = rollout_cost + env .terminal_cost (rollout , ctrl .U )
310- env .draw_rollouts ([rollout ], label = key )
311330
312331 prev_state = copy .deepcopy (state )
313332 state = env .step (u )
333+
334+ if plot_trajectory_candidates :
335+ from matplotlib import cm
336+ # only plot some candidates rather than all of them
337+ num_candidates = min (10 , ctrl .K )
338+ # for the combined trajectory
339+ color = []
340+ end_color = []
341+ rollouts = []
342+ linewidth = []
343+ labels = []
344+ # for all the candidates
345+ # create matplotlib color map based on cost
346+ cost = ctrl .cost_total .cpu ()
347+ best_idx = torch .argsort (cost )
348+
349+ norm = matplotlib .colors .Normalize (vmin = cost .min (), vmax = cost [best_idx ][:num_candidates * 5 ].max ())
350+ m = cm .ScalarMappable (norm = norm , cmap = cm .jet )
351+
352+ traj_color = m .to_rgba (cost )
353+ # lower alpha
354+ traj_color [:, 3 ] = 0.2
355+
356+ # get rollouts per sampled action trajectory
357+ for j in range (num_candidates ):
358+ idx = best_idx [j ]
359+ this_U = ctrl .actions [0 , idx ]
360+ rollouts .append (ctrl .get_rollouts (state , U = this_U )[0 ])
361+ color .append (traj_color [idx ])
362+ end_color .append ([1 , 0 , 0 , 0.2 ])
363+ linewidth .append (1 )
364+ labels .append (None )
365+ color .append ("skyblue" )
366+ end_color .append ("tab:red" )
367+ rollouts .append (rollout )
368+ linewidth .append (2 )
369+ labels .append (key )
370+ env .draw_rollouts (rollouts , color = color , end_state_color = end_color , linewidth = linewidth , label = labels )
371+ else :
372+ # just draw the single state rollout
373+ env .draw_rollouts ([rollout ], label = key )
374+
314375 env .draw_trajectory_step (prev_state , state )
315376
316377 if save_img :
@@ -500,7 +561,7 @@ def main(plot_only=False):
500561
501562 for ctrl in [mmppi , kmppi , smppi ]:
502563 do_control (env , ctrl , ch , seeds = range (5 ), run_steps = 20 , num_refinement_steps = 1 , save_img = True ,
503- plot_single = False )
564+ plot_single = False , evaluate_running_cost = False , plot_trajectory_candidates = True )
504565
505566
506567if __name__ == "__main__" :
0 commit comments