Skip to content

Commit 4c3db3a

Browse files
authored
Divergence of gridworld rewards and reward heatmap improvements (#11)
* Move rewards to separate config file * tabular: handle affine transforms, support discounting * Add CLI script to plot divergence of gridworld rewards * Paper quality gridworld divergence figure * Bugfix: handle None order * add type annotations * Disable TeX in unit test * Make hatch background color depend on foreground * Consistent color scale between reward plots * Add center goal * Tweaks to figure * Add accidentally deleted bash script
1 parent 6bdeec5 commit 4c3db3a

12 files changed

Lines changed: 537 additions & 307 deletions

src/evaluating_rewards/analysis/gridworld_heatmap.py

Lines changed: 51 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,15 @@
1717
This is currently only used for illustrative examples in the paper;
1818
none of the actual experiments are gridworlds."""
1919

20-
import collections
2120
import enum
21+
import functools
2222
import math
23-
from typing import Tuple
23+
from typing import Optional, Tuple
2424
from unittest import mock
2525

2626
import matplotlib
27-
import matplotlib.collections as mcollections
2827
import matplotlib.colors as mcolors
28+
import matplotlib.patches as mpatches
2929
import matplotlib.pyplot as plt
3030
import mdptoolbox
3131
import numpy as np
@@ -63,12 +63,15 @@ class Actions(enum.IntEnum):
6363
OFFSETS[(0, 0)] = np.array([0.5, 0.5])
6464

6565

66-
def shape(state_reward: np.ndarray, state_potential: np.ndarray) -> np.ndarray:
66+
def shape(
67+
state_reward: np.ndarray, state_potential: np.ndarray, discount: float = 0.99
68+
) -> np.ndarray:
6769
"""Shape `state_reward` with `state_potential`.
6870
6971
Args:
7072
state_reward: a two-dimensional array, indexed by `(i,j)`.
7173
state_potential: a two-dimensional array of the same shape as `state_reward`.
74+
discount: discount rate of MDP.
7275
7376
Returns:
7477
A state-action reward `sa_reward`. This is three-dimensional array,
@@ -90,7 +93,7 @@ def shape(state_reward: np.ndarray, state_potential: np.ndarray) -> np.ndarray:
9093
for x_delta, y_delta in ACTION_DELTA.values():
9194
axis = 0 if x_delta else 1
9295
delta = x_delta + y_delta
93-
new_potential = np.roll(padded_potential, -delta, axis=axis)
96+
new_potential = discount * np.roll(padded_potential, -delta, axis=axis)
9497
shaped = padded_reward + new_potential - padded_potential
9598
res.append(shaped[1:-1, 1:-1])
9699

@@ -186,10 +189,14 @@ def _reward_make_fig(xlen: int, ylen: int) -> Tuple[plt.Figure, plt.Axes]:
186189
return fig, ax
187190

188191

189-
def _reward_make_color_map(state_action_reward: np.ndarray) -> matplotlib.cm.ScalarMappable:
190-
norm = mcolors.Normalize(
191-
vmin=np.nanmin(state_action_reward), vmax=np.nanmax(state_action_reward)
192-
)
192+
def _reward_make_color_map(
193+
state_action_reward: np.ndarray, vmin: Optional[float], vmax: Optional[float]
194+
) -> matplotlib.cm.ScalarMappable:
195+
if vmin is None:
196+
vmin = np.nanmin(state_action_reward)
197+
if vmax is None:
198+
vmax = np.nanmin(state_action_reward)
199+
norm = mcolors.Normalize(vmin=vmin, vmax=vmax)
193200
return matplotlib.cm.ScalarMappable(norm=norm)
194201

195202

@@ -203,7 +210,7 @@ def _reward_draw_spline(
203210
mappable: matplotlib.cm.ScalarMappable,
204211
annot_padding: float,
205212
ax: plt.Axes,
206-
) -> Tuple[np.ndarray, Tuple[float, ...]]:
213+
) -> Tuple[np.ndarray, Tuple[float, ...], str]:
207214
# Compute shape position and color
208215
pos = np.array([x, y])
209216
direction = np.array(ACTION_DELTA[action])
@@ -217,6 +224,7 @@ def _reward_draw_spline(
217224
text = f"{reward:.0f}"
218225
lum = sns.utils.relative_luminance(color)
219226
text_color = ".15" if lum > 0.408 else "w"
227+
hatch_color = ".5" if lum > 0.408 else "w"
220228
xy = pos + 0.5
221229

222230
if tuple(direction) != (0, 0):
@@ -226,19 +234,27 @@ def _reward_draw_spline(
226234
text, xy=xy, ha="center", va="center", color=text_color, fontweight=fontweight,
227235
)
228236

229-
return vert, color
237+
return vert, color, hatch_color
238+
239+
240+
def _make_triangle(vert, color, **kwargs):
241+
return mpatches.Polygon(xy=vert, facecolor=color, **kwargs)
242+
243+
244+
def _make_circle(vert, color, radius, **kwargs):
245+
return mpatches.Circle(xy=vert, radius=radius, facecolor=color, **kwargs)
230246

231247

232248
def _reward_draw(
233249
state_action_reward: np.ndarray,
250+
discount: float,
234251
fig: plt.Figure,
235252
ax: plt.Axes,
236253
mappable: matplotlib.cm.ScalarMappable,
237254
from_dest: bool,
238255
edgecolor: str = "gray",
239-
hatchcolor: str = "white",
240256
) -> None:
241-
optimal_actions = optimal_mask(state_action_reward)
257+
optimal_actions = optimal_mask(state_action_reward, discount)
242258

243259
circle_area_pt = 200
244260
circle_radius_pt = math.sqrt(circle_area_pt / math.pi)
@@ -248,8 +264,8 @@ def _reward_draw(
248264
circle_radius_data = ax.transData.inverted().transform(corner_display + circle_radius_display)
249265
annot_padding = 0.25 + 0.5 * circle_radius_data[0]
250266

251-
verts = collections.defaultdict(lambda: collections.defaultdict(list))
252-
colors = collections.defaultdict(lambda: collections.defaultdict(list))
267+
triangle_patches = []
268+
circle_patches = []
253269

254270
it = np.nditer(state_action_reward, flags=["multi_index"])
255271
while not it.finished:
@@ -262,54 +278,35 @@ def _reward_draw(
262278
assert action != 0
263279
continue
264280

265-
vert, color = _reward_draw_spline(
281+
vert, color, hatch_color = _reward_draw_spline(
266282
x, y, action, optimal, reward, from_dest, mappable, annot_padding, ax
267283
)
268284

269-
geom = "circle" if action == 0 else "triangle"
270-
verts[geom][optimal].append(vert)
271-
colors[geom][optimal].append(color)
272-
273-
circle_collections = []
274-
triangle_collections = []
275-
276-
def _make_triangle(optimal, **kwargs):
277-
return mcollections.PolyCollection(
278-
verts=verts["triangle"][optimal], facecolors=colors["triangle"][optimal], **kwargs,
279-
)
280-
281-
def _make_circle(optimal, **kwargs):
282-
circle_offsets = verts["circle"][optimal]
283-
return mcollections.CircleCollection(
284-
sizes=[circle_area_pt] * len(circle_offsets),
285-
facecolors=colors["circle"][optimal],
286-
offsets=circle_offsets,
287-
transOffset=ax.transData,
288-
**kwargs,
289-
)
290-
291-
maker_collection_dict = {
292-
_make_triangle: triangle_collections,
293-
_make_circle: circle_collections,
294-
}
295-
296-
for optimal in [False, True]:
297285
hatch = "xx" if optimal else None
298-
299-
for maker_fn, cols in maker_collection_dict.items():
300-
cols.append(maker_fn(optimal, edgecolors=edgecolor))
301-
if hatch: # draw the hatch using a different color
302-
cols.append(maker_fn(optimal, edgecolors=hatchcolor, linewidth=0, hatch=hatch))
303-
304-
for cols in triangle_collections + circle_collections:
305-
ax.add_collection(cols)
286+
if action == 0:
287+
fn = functools.partial(_make_circle, radius=circle_radius_data[0])
288+
else:
289+
fn = _make_triangle
290+
patches = circle_patches if action == 0 else triangle_patches
291+
if hatch: # draw the hatch using a different color
292+
patches.append(fn(vert, tuple(color), linewidth=1, edgecolor=hatch_color, hatch=hatch))
293+
patches.append(fn(vert, tuple(color), linewidth=1, edgecolor=edgecolor, fill=False))
294+
else:
295+
patches.append(fn(vert, tuple(color), linewidth=1, edgecolor=edgecolor))
296+
297+
for p in triangle_patches + circle_patches:
298+
# need to draw circles on top of triangles
299+
ax.add_patch(p)
306300

307301

308302
def plot_gridworld_reward(
309303
state_action_reward: np.ndarray,
304+
discount: float = 0.99,
310305
from_dest: bool = False,
311306
cbar_format: str = "%.0f",
312307
cbar_fraction: float = 0.05,
308+
vmin: Optional[float] = None,
309+
vmax: Optional[float] = None,
313310
) -> plt.Figure:
314311
"""
315312
Plots a heatmap of reward for the gridworld.
@@ -330,7 +327,7 @@ def plot_gridworld_reward(
330327
xlen, ylen, num_actions = state_action_reward.shape
331328
assert num_actions == len(ACTION_DELTA)
332329
fig, ax = _reward_make_fig(xlen, ylen)
333-
mappable = _reward_make_color_map(state_action_reward)
334-
_reward_draw(state_action_reward, fig, ax, mappable, from_dest)
330+
mappable = _reward_make_color_map(state_action_reward, vmin, vmax)
331+
_reward_draw(state_action_reward, discount, fig, ax, mappable, from_dest)
335332
fig.colorbar(mappable, format=cbar_format, fraction=cbar_fraction)
336333
return fig
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
"""Illustrative rewards for gridworlds."""
2+
3+
import numpy as np
4+
5+
SPARSE_GOAL = np.array([[0, 0, 0], [0, 0, 0], [0, 0, 1]])
6+
7+
CENTER_GOAL = np.array([[0, 0, 0], [0, 1, 0], [0, 0, 0]])
8+
9+
OBSTACLE_COURSE = np.array([[0, -1, -1], [0, 0, 0], [-1, -1, 5]])
10+
11+
CLIFF_WALK = np.array([[0, -1, -1], [0, 0, 0], [-5, -5, 5]])
12+
13+
MANHATTAN_FROM_GOAL = np.array([[4, 3, 2], [3, 2, 1], [2, 1, 0]])
14+
15+
ZERO = np.zeros((3, 3))
16+
17+
REWARDS = {
18+
# Equivalent rewards
19+
"sparse_goal": {"state_reward": SPARSE_GOAL, "potential": ZERO},
20+
"sparse_goal_shift": {"state_reward": SPARSE_GOAL + 1, "potential": ZERO},
21+
"sparse_goal_scale": {"state_reward": SPARSE_GOAL * 10, "potential": ZERO},
22+
"dense_goal": {"state_reward": SPARSE_GOAL, "potential": -MANHATTAN_FROM_GOAL},
23+
"antidense_goal": {"state_reward": SPARSE_GOAL, "potential": MANHATTAN_FROM_GOAL},
24+
# Non-equivalent rewards
25+
"transformed_goal": {
26+
# Shifted, rescaled and reshaped sparse goal.
27+
"state_reward": SPARSE_GOAL * 4 - 1,
28+
"potential": -MANHATTAN_FROM_GOAL * 4,
29+
},
30+
"center_goal": {
31+
# Goal is in center
32+
"state_reward": CENTER_GOAL,
33+
"potential": ZERO,
34+
},
35+
"dirt_path": {
36+
# Some minor penalties to avoid to reach goal.
37+
#
38+
# Optimal policy for this is optimal in `SPARSE_GOAL`, but not equivalent.
39+
# Think may come apart in some dynamics but not particularly intuitively.
40+
"state_reward": OBSTACLE_COURSE,
41+
"potential": ZERO,
42+
},
43+
"cliff_walk": {
44+
# Avoid cliff to reach goal. Same set of optimal policies as `obstacle_course` in
45+
# deterministic dynamics, but not equivalent.
46+
#
47+
# Optimal policy differs in sufficiently slippery gridworlds as want to stay on top line
48+
# to avoid chance of falling off cliff.
49+
"state_reward": CLIFF_WALK,
50+
"potential": ZERO,
51+
},
52+
"sparse_penalty": {
53+
# Negative of `sparse_goal`.
54+
"state_reward": -SPARSE_GOAL,
55+
"potential": ZERO,
56+
},
57+
"all_zero": {
58+
# All zero reward function
59+
"state_reward": ZERO,
60+
"potential": ZERO,
61+
},
62+
}

src/evaluating_rewards/analysis/latex/figsymbols.sty

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,11 @@
1515
\newcommand{\sparse}{\texttt{S}}
1616
\newcommand{\dense}{\texttt{D}}
1717
\newcommand{\magnitude}{\texttt{M}}
18-
\newcommand{\zeroreward}{\texttt{Zero}}
18+
\newcommand{\zeroreward}{\texttt{Zero}}
19+
20+
\newcommand{\sparsegoal}{\texttt{Sparse}}
21+
\newcommand{\densegoal}{\texttt{Dense}}
22+
\newcommand{\sparsepenalty}{\texttt{Penalty}}
23+
\newcommand{\centergoal}{\texttt{Center}}
24+
\newcommand{\dirtpath}{\texttt{Path}}
25+
\newcommand{\cliffwalk}{\texttt{Cliff}}

src/evaluating_rewards/analysis/plot_all_gridworld_heatmap.sh renamed to src/evaluating_rewards/analysis/plot_all_gridworld_reward.sh

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
#!/bin/bash
22

33
CONFIGS="sparse_goal sparse_goal_shift sparse_goal_scale \
4-
dense_goal antidense_goal transformed_goal \
5-
obstacle_course cliff_walk sparse_anti_goal all_zero"
4+
dense_goal antidense_goal transformed_goal center_goal \
5+
dirt_path cliff_walk sparse_penalty all_zero"
66

7-
parallel --header : python -m evaluating_rewards.analysis.plot_gridworld_heatmap \
7+
parallel --header : python -m evaluating_rewards.analysis.plot_gridworld_reward \
88
with {config} \
99
::: config ${CONFIGS}

src/evaluating_rewards/analysis/plot_divergence_heatmap.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030

3131

3232
def horizontal_ticks() -> None:
33-
# lazy import to allow custom backend
3433
plt.xticks(rotation="horizontal")
3534
plt.yticks(rotation="horizontal")
3635

@@ -50,10 +49,9 @@ def default_config():
5049
# Figure parameters
5150
heatmap_kwargs = {
5251
"masks": {"all": [visualize.always_true]},
53-
"order": None,
5452
"after_plot": horizontal_ticks,
5553
}
56-
styles = ["paper", "heatmap-1col", "tex"]
54+
styles = ["paper", "heatmap", "heatmap-1col", "tex"]
5755
save_kwargs = {
5856
"fmt": "pdf",
5957
}
@@ -239,9 +237,6 @@ def cfg_filter(cfg):
239237
stats = results.load_multiple_stats(data_dir, keys, cfg_filter=cfg_filter)
240238
res = results.pipeline(stats)
241239
loss = res["loss"]["loss"]
242-
heatmap_kwargs = dict(heatmap_kwargs)
243-
if heatmap_kwargs.get("order") is None:
244-
heatmap_kwargs["order"] = loss.index.levels[0]
245240

246241
figs = {}
247242
figs["loss"] = visualize.loss_heatmap(loss, res["loss"]["unwrapped_loss"])

0 commit comments

Comments
 (0)