Skip to content

Commit 8e94091

Browse files
committed
Visualize trajectory samples
1 parent 2140356 commit 8e94091

2 files changed

Lines changed: 75 additions & 15 deletions

File tree

README.md

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -108,16 +108,15 @@ trajectory is interpolated using the kernel. The kernel is applied to the contro
108108

109109
MPPI without smoothing
110110

111-
![MPPI](https://imgur.com/aXSo3Ib.gif)
111+
![MPPI](https://imgur.com/9wEcT2s.gif)
112112

113113
[SMPPI](https://arxiv.org/pdf/2112.09988) smoothing by sampling noise in the action derivative space doesn't work well on this problem
114114

115-
![SMPPI](https://imgur.com/y1hvqlD.gif)
115+
![SMPPI](https://imgur.com/xwYy3aj.gif)
116116

117117
KMPPI smoothing with RBF kernel works well
118118

119-
![KMPPI](https://imgur.com/mZmbC4S.gif)
120-
119+
![KMPPI](https://imgur.com/IG1Zrtd.gif)
121120

122121

123122
## Autotune

tests/smooth_mppi.py

Lines changed: 72 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,13 @@
1010
from pytorch_mppi import mppi
1111
import pytorch_seed
1212
import logging
13+
import collections
14+
15+
def is_sequence(obj):
16+
if isinstance(obj, str):
17+
return False
18+
return isinstance(obj, collections.abc.Sequence)
1319

14-
# import window_recorder
1520

1621
plt.switch_backend('Qt5Agg')
1722

@@ -156,19 +161,33 @@ def start_visualization(self):
156161
# self.draw_start()
157162
self.draw_goal()
158163

159-
def draw_rollouts(self, rollouts, color="skyblue", label=None):
164+
@staticmethod
165+
def get_v(i, values, rollouts):
166+
if type(values) != str and is_sequence(values) and len(values) == len(rollouts):
167+
c = values[i]
168+
else:
169+
c = values
170+
return c
171+
172+
def draw_rollouts(self, rollouts, color="skyblue", label=None, end_state_color="tab:red", linewidth=1.5):
160173
if not self.visualize:
161174
return
162175
self.clear_artist(self.rollout_artist)
163176
artists = []
164-
for rollout in rollouts:
165-
# r = torch.cat((self.start.reshape(1, -1), rollout))
177+
for i, rollout in enumerate(rollouts):
178+
# prepend start state
179+
rollout = torch.cat([self.state.view(1, -1), rollout], dim=0)
166180
r = rollout.cpu()
167181
artists += [self.ax.scatter(r[0, 0], r[0, 1], color="tab:blue")]
168-
artists += self.ax.plot(r[:, 0], r[:, 1], color=color, label=label)
169-
artists += [self.ax.scatter(r[-1, 0], r[-1, 1], color="tab:red")]
182+
# if color is a string treat it as a single color, otherwise treat it as a list of colors
183+
artists += self.ax.plot(r[:, 0], r[:, 1],
184+
color=self.get_v(i, color, rollouts),
185+
label=self.get_v(i, label, rollouts),
186+
linewidth=self.get_v(i, linewidth, rollouts))
187+
artists += [self.ax.scatter(r[-1, 0], r[-1, 1], color=self.get_v(i, end_state_color, rollouts))]
170188
self.rollout_artist = artists
171-
self.ax.legend()
189+
if label is not None:
190+
self.ax.legend(loc = "upper right")
172191
plt.pause(0.001)
173192

174193
def draw_trajectory_step(self, prev_state, cur_state, color="tab:blue"):
@@ -179,6 +198,7 @@ def draw_trajectory_step(self, prev_state, cur_state, color="tab:blue"):
179198
artists = self.trajectory_artist
180199
artists += self.ax.plot([prev_state[0].cpu(), cur_state[0].cpu()],
181200
[prev_state[1].cpu(), cur_state[1].cpu()], color=color)
201+
plt.draw()
182202
plt.pause(0.001)
183203

184204
def clear_trajectory(self):
@@ -265,8 +285,8 @@ def make_gif_ffmpeg(imgs_dir, gif_name, fps=6):
265285
subprocess.run(cmd)
266286

267287

268-
def do_control(env, ctrl, ch, seeds=(0,), run_steps=20, num_refinement_steps=1, save_img=True, plot_single=False):
269-
evaluate_running_cost = True
288+
def do_control(env, ctrl, ch, seeds=(0,), run_steps=20, num_refinement_steps=1, save_img=True, plot_single=False,
289+
evaluate_running_cost=True, plot_trajectory_candidates=False):
270290
if save_img:
271291
os.makedirs("images", exist_ok=True)
272292
os.makedirs("images/runs", exist_ok=True)
@@ -307,10 +327,51 @@ def do_control(env, ctrl, ch, seeds=(0,), run_steps=20, num_refinement_steps=1,
307327
for t in range(len(rollout) - 1):
308328
rollout_cost = rollout_cost + env.running_cost(rollout[t], ctrl.U[t])
309329
rollout_cost = rollout_cost + env.terminal_cost(rollout, ctrl.U)
310-
env.draw_rollouts([rollout], label=key)
311330

312331
prev_state = copy.deepcopy(state)
313332
state = env.step(u)
333+
334+
if plot_trajectory_candidates:
335+
from matplotlib import cm
336+
# only plot some candidates rather than all of them
337+
num_candidates = min(10, ctrl.K)
338+
# for the combined trajectory
339+
color = []
340+
end_color = []
341+
rollouts = []
342+
linewidth = []
343+
labels = []
344+
# for all the candidates
345+
# create matplotlib color map based on cost
346+
cost = ctrl.cost_total.cpu()
347+
best_idx = torch.argsort(cost)
348+
349+
norm = matplotlib.colors.Normalize(vmin=cost.min(), vmax=cost[best_idx][:num_candidates*5].max())
350+
m = cm.ScalarMappable(norm=norm, cmap=cm.jet)
351+
352+
traj_color = m.to_rgba(cost)
353+
# lower alpha
354+
traj_color[:, 3] = 0.2
355+
356+
# get rollouts per sampled action trajectory
357+
for j in range(num_candidates):
358+
idx = best_idx[j]
359+
this_U = ctrl.actions[0, idx]
360+
rollouts.append(ctrl.get_rollouts(state, U=this_U)[0])
361+
color.append(traj_color[idx])
362+
end_color.append([1, 0, 0, 0.2])
363+
linewidth.append(1)
364+
labels.append(None)
365+
color.append("skyblue")
366+
end_color.append("tab:red")
367+
rollouts.append(rollout)
368+
linewidth.append(2)
369+
labels.append(key)
370+
env.draw_rollouts(rollouts, color=color, end_state_color=end_color, linewidth=linewidth, label=labels)
371+
else:
372+
# just draw the single state rollout
373+
env.draw_rollouts([rollout], label=key)
374+
314375
env.draw_trajectory_step(prev_state, state)
315376

316377
if save_img:
@@ -500,7 +561,7 @@ def main(plot_only=False):
500561

501562
for ctrl in [mmppi, kmppi, smppi]:
502563
do_control(env, ctrl, ch, seeds=range(5), run_steps=20, num_refinement_steps=1, save_img=True,
503-
plot_single=False)
564+
plot_single=False, evaluate_running_cost=False, plot_trajectory_candidates=True)
504565

505566

506567
if __name__ == "__main__":

0 commit comments

Comments
 (0)