Skip to content

Commit 6bdeec5

Browse files
authored
Heatmaps of reward for illustrative gridworlds (#10)
* Prototype gridworld reward heatmaps * Example script to illustrate different figures * Dead code removal, refactor * Add CLI script * Clean up grid and ticks * Tweak formatting * Add some illustrative example configs * Add new configs, bash script to generate all figs, new figure styles * Keep annotations centered across a range of figure sizes * Bold and hatch optimal actions * Add missing dependency * Tweak formatting config * Tweak rewards, figure styles * self review
1 parent 24d90b2 commit 6bdeec5

12 files changed

Lines changed: 602 additions & 7 deletions

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ imitation @ git+https://github.com/HumanCompatibleAI/imitation.git@e99844
33
matplotlib
44
numpy
55
pandas
6+
pymdptoolbox
67
seaborn
78
setuptools
89
scipy
Lines changed: 336 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,336 @@
1+
# Copyright 2020 Adam Gleave
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Heatmaps for rewards in gridworld environments.
16+
17+
This is currently only used for illustrative examples in the paper;
18+
none of the actual experiments are gridworlds."""
19+
20+
import collections
21+
import enum
22+
import math
23+
from typing import Tuple
24+
from unittest import mock
25+
26+
import matplotlib
27+
import matplotlib.collections as mcollections
28+
import matplotlib.colors as mcolors
29+
import matplotlib.pyplot as plt
30+
import mdptoolbox
31+
import numpy as np
32+
import seaborn as sns
33+
34+
35+
class Actions(enum.IntEnum):
36+
STAY = 0
37+
LEFT = 1
38+
UP = 2
39+
RIGHT = 3
40+
DOWN = 4
41+
42+
43+
# (x,y) offset caused by taking an action
44+
ACTION_DELTA = {
45+
Actions.STAY: (0, 0),
46+
Actions.LEFT: (-1, 0),
47+
Actions.UP: (0, 1),
48+
Actions.RIGHT: (1, 0),
49+
Actions.DOWN: (0, -1),
50+
}
51+
52+
# Counter-clockwise, corners of a unit square, centred at (0.5, 0.5).
53+
CORNERS = [(0, 0), (0, 1), (1, 1), (1, 0)]
54+
# Vertices subdividing the unit square for each action
55+
OFFSETS = {
56+
# Triangles, cutting unit-square into quarters
57+
direction: np.array(
58+
[CORNERS[action.value - 1], [0.5, 0.5], CORNERS[action.value % len(CORNERS)]]
59+
)
60+
for action, direction in ACTION_DELTA.items()
61+
}
62+
# Circle at the center
63+
OFFSETS[(0, 0)] = np.array([0.5, 0.5])
64+
65+
66+
def shape(state_reward: np.ndarray, state_potential: np.ndarray) -> np.ndarray:
67+
"""Shape `state_reward` with `state_potential`.
68+
69+
Args:
70+
state_reward: a two-dimensional array, indexed by `(i,j)`.
71+
state_potential: a two-dimensional array of the same shape as `state_reward`.
72+
73+
Returns:
74+
A state-action reward `sa_reward`. This is three-dimensional array,
75+
indexed by `(i,j,a)`, where `a` is an action indexing into `DIRECTIONS`.
76+
`sa_reward[i,j,a] = state_reward[i, j] + state_potential[i', j'] - state_potential[i,j]`,
77+
where `i', j'` is the successor state after taking action `a`.
78+
"""
79+
assert state_reward.ndim == 2
80+
assert state_reward.shape == state_potential.shape
81+
82+
padded_reward = np.pad(
83+
state_reward.astype(np.float32), pad_width=[(1, 1), (1, 1)], constant_values=np.nan
84+
)
85+
padded_potential = np.pad(
86+
state_potential.astype(np.float32), pad_width=[(1, 1), (1, 1)], constant_values=np.nan
87+
)
88+
89+
res = []
90+
for x_delta, y_delta in ACTION_DELTA.values():
91+
axis = 0 if x_delta else 1
92+
delta = x_delta + y_delta
93+
new_potential = np.roll(padded_potential, -delta, axis=axis)
94+
shaped = padded_reward + new_potential - padded_potential
95+
res.append(shaped[1:-1, 1:-1])
96+
97+
return np.array(res).transpose((1, 2, 0))
98+
99+
100+
def _make_transitions(
101+
transitions: np.ndarray,
102+
low_action: int,
103+
high_action: int,
104+
states: np.ndarray,
105+
idx: np.ndarray,
106+
n: int,
107+
) -> None:
108+
transitions[low_action, states[idx == 0], states[idx == 0]] = 1
109+
transitions[low_action, states[idx > 0], states[idx < n - 1]] = 1
110+
transitions[high_action, states[idx == n - 1], states[idx == n - 1]] = 1
111+
transitions[high_action, states[idx < n - 1], states[idx > 0]] = 1
112+
113+
114+
def build_mdp(state_action_reward: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
115+
"""Create transition matrix for deterministic gridworld and reshape reward."""
116+
xlen, ylen, na = state_action_reward.shape
117+
ns = xlen * ylen
118+
119+
transitions = np.zeros((na, ns, ns))
120+
transitions[Actions.STAY.value, :, :] = np.eye(ns, ns)
121+
states = np.arange(ns)
122+
xs = states % xlen
123+
ys = states // ylen
124+
_make_transitions(transitions, Actions.LEFT.value, Actions.RIGHT.value, states, ys, ylen)
125+
_make_transitions(transitions, Actions.DOWN.value, Actions.UP.value, states, xs, xlen)
126+
127+
reward = state_action_reward.copy()
128+
reward = reward.reshape(ns, na)
129+
# We use NaN for transitions that would go outside the gridworld.
130+
# But in above transition dynamics these are equivalent to stay, so rewrite.
131+
mask = np.isnan(reward)
132+
stay_reward = reward[:, Actions.STAY.value]
133+
stay_tiled = np.tile(stay_reward, (na, 1)).T
134+
reward[mask] = stay_tiled[mask]
135+
assert np.isfinite(reward).all()
136+
137+
return transitions, reward
138+
139+
140+
def _no_op_iter(*args, **kwargs):
141+
"""Does nothing, workaround for bug in pymdptoolbox GH#32."""
142+
del args, kwargs
143+
144+
145+
def compute_qvalues(state_action_reward: np.ndarray, discount: float) -> np.ndarray:
146+
"""Computes the Q-values of `state_action_reward` under deterministic dynamics."""
147+
transitions, reward = build_mdp(state_action_reward)
148+
149+
# TODO(adam): remove this workaround once GH pymdptoolbox #32 merged.
150+
with mock.patch("mdptoolbox.mdp.ValueIteration._boundIter", new=_no_op_iter):
151+
vi = mdptoolbox.mdp.ValueIteration(
152+
transitions=transitions, reward=reward, discount=discount
153+
)
154+
vi.run()
155+
q_values = reward + discount * (transitions * vi.V).sum(2).T
156+
return q_values
157+
158+
159+
def optimal_mask(state_action_reward: np.ndarray, discount: float = 0.99) -> np.ndarray:
160+
"""Computes the optimal actions for each state in `state_action_reward`."""
161+
q_values = compute_qvalues(state_action_reward, discount)
162+
best_q = q_values.max(axis=1)[:, np.newaxis]
163+
optimal_action = np.isclose(q_values, best_q)
164+
return optimal_action.reshape(state_action_reward.shape)
165+
166+
167+
def _set_ticks(n: int, subaxis: matplotlib.axis.Axis) -> None:
168+
subaxis.set_ticks(np.arange(0, n + 1), minor=True)
169+
subaxis.set_ticks(np.arange(n) + 0.5)
170+
subaxis.set_ticklabels(np.arange(n))
171+
172+
173+
def _reward_make_fig(xlen: int, ylen: int) -> Tuple[plt.Figure, plt.Axes]:
174+
"""Construct figure and set sensible defaults."""
175+
fig, ax = plt.subplots(1, 1)
176+
# Axes limits
177+
ax.set_xlim(0, xlen)
178+
ax.set_ylim(0, ylen)
179+
# Make ticks centred in each cell
180+
_set_ticks(xlen, ax.xaxis)
181+
_set_ticks(ylen, ax.yaxis)
182+
# Draw grid along minor ticks, then remove those ticks so they don't protrude
183+
ax.grid(which="minor", color="k")
184+
ax.tick_params(which="minor", length=0, width=0)
185+
186+
return fig, ax
187+
188+
189+
def _reward_make_color_map(state_action_reward: np.ndarray) -> matplotlib.cm.ScalarMappable:
190+
norm = mcolors.Normalize(
191+
vmin=np.nanmin(state_action_reward), vmax=np.nanmax(state_action_reward)
192+
)
193+
return matplotlib.cm.ScalarMappable(norm=norm)
194+
195+
196+
def _reward_draw_spline(
197+
x: int,
198+
y: int,
199+
action: int,
200+
optimal: bool,
201+
reward: float,
202+
from_dest: bool,
203+
mappable: matplotlib.cm.ScalarMappable,
204+
annot_padding: float,
205+
ax: plt.Axes,
206+
) -> Tuple[np.ndarray, Tuple[float, ...]]:
207+
# Compute shape position and color
208+
pos = np.array([x, y])
209+
direction = np.array(ACTION_DELTA[action])
210+
if from_dest:
211+
pos = pos + direction
212+
direction = -direction
213+
vert = pos + OFFSETS[tuple(direction)]
214+
color = mappable.to_rgba(reward)
215+
216+
# Add annotation
217+
text = f"{reward:.0f}"
218+
lum = sns.utils.relative_luminance(color)
219+
text_color = ".15" if lum > 0.408 else "w"
220+
xy = pos + 0.5
221+
222+
if tuple(direction) != (0, 0):
223+
xy = xy + annot_padding * direction
224+
fontweight = "bold" if optimal else None
225+
ax.annotate(
226+
text, xy=xy, ha="center", va="center", color=text_color, fontweight=fontweight,
227+
)
228+
229+
return vert, color
230+
231+
232+
def _reward_draw(
233+
state_action_reward: np.ndarray,
234+
fig: plt.Figure,
235+
ax: plt.Axes,
236+
mappable: matplotlib.cm.ScalarMappable,
237+
from_dest: bool,
238+
edgecolor: str = "gray",
239+
hatchcolor: str = "white",
240+
) -> None:
241+
optimal_actions = optimal_mask(state_action_reward)
242+
243+
circle_area_pt = 200
244+
circle_radius_pt = math.sqrt(circle_area_pt / math.pi)
245+
circle_radius_in = circle_radius_pt / 72
246+
corner_display = ax.transData.transform([0.0, 0.0])
247+
circle_radius_display = fig.dpi_scale_trans.transform([circle_radius_in, 0])
248+
circle_radius_data = ax.transData.inverted().transform(corner_display + circle_radius_display)
249+
annot_padding = 0.25 + 0.5 * circle_radius_data[0]
250+
251+
verts = collections.defaultdict(lambda: collections.defaultdict(list))
252+
colors = collections.defaultdict(lambda: collections.defaultdict(list))
253+
254+
it = np.nditer(state_action_reward, flags=["multi_index"])
255+
while not it.finished:
256+
reward = it[0]
257+
x, y, action = it.multi_index
258+
optimal = optimal_actions[it.multi_index]
259+
it.iternext()
260+
261+
if not np.isfinite(reward):
262+
assert action != 0
263+
continue
264+
265+
vert, color = _reward_draw_spline(
266+
x, y, action, optimal, reward, from_dest, mappable, annot_padding, ax
267+
)
268+
269+
geom = "circle" if action == 0 else "triangle"
270+
verts[geom][optimal].append(vert)
271+
colors[geom][optimal].append(color)
272+
273+
circle_collections = []
274+
triangle_collections = []
275+
276+
def _make_triangle(optimal, **kwargs):
277+
return mcollections.PolyCollection(
278+
verts=verts["triangle"][optimal], facecolors=colors["triangle"][optimal], **kwargs,
279+
)
280+
281+
def _make_circle(optimal, **kwargs):
282+
circle_offsets = verts["circle"][optimal]
283+
return mcollections.CircleCollection(
284+
sizes=[circle_area_pt] * len(circle_offsets),
285+
facecolors=colors["circle"][optimal],
286+
offsets=circle_offsets,
287+
transOffset=ax.transData,
288+
**kwargs,
289+
)
290+
291+
maker_collection_dict = {
292+
_make_triangle: triangle_collections,
293+
_make_circle: circle_collections,
294+
}
295+
296+
for optimal in [False, True]:
297+
hatch = "xx" if optimal else None
298+
299+
for maker_fn, cols in maker_collection_dict.items():
300+
cols.append(maker_fn(optimal, edgecolors=edgecolor))
301+
if hatch: # draw the hatch using a different color
302+
cols.append(maker_fn(optimal, edgecolors=hatchcolor, linewidth=0, hatch=hatch))
303+
304+
for cols in triangle_collections + circle_collections:
305+
ax.add_collection(cols)
306+
307+
308+
def plot_gridworld_reward(
309+
state_action_reward: np.ndarray,
310+
from_dest: bool = False,
311+
cbar_format: str = "%.0f",
312+
cbar_fraction: float = 0.05,
313+
) -> plt.Figure:
314+
"""
315+
Plots a heatmap of reward for the gridworld.
316+
317+
Args:
318+
- state_action_reward: a three-dimensional array specifying the gridworld reward.
319+
- from_dest: if True, the triangular wedges represent reward when arriving into this
320+
cell from the adjacent cell; if False, represent reward when leaving this cell into
321+
the adjacent cell.
322+
- annot_padding: a fraction of a supercell to offset the annotation from the centre.
323+
- cbar_fraction: the fraction of the axes the colorbar takes up.
324+
325+
Returns:
326+
A heatmap consisting of a "supercell" for each state `(i,j)` in the original gridworld.
327+
This supercell contains a central circle, representing the no-op action reward and four
328+
triangular wedges, representing the left, up, right and down action rewards.
329+
"""
330+
xlen, ylen, num_actions = state_action_reward.shape
331+
assert num_actions == len(ACTION_DELTA)
332+
fig, ax = _reward_make_fig(xlen, ylen)
333+
mappable = _reward_make_color_map(state_action_reward)
334+
_reward_draw(state_action_reward, fig, ax, mappable, from_dest)
335+
fig.colorbar(mappable, format=cbar_format, fraction=cbar_fraction)
336+
return fig
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
#!/bin/bash
2+
3+
CONFIGS="sparse_goal sparse_goal_shift sparse_goal_scale \
4+
dense_goal antidense_goal transformed_goal \
5+
obstacle_course cliff_walk sparse_anti_goal all_zero"
6+
7+
parallel --header : python -m evaluating_rewards.analysis.plot_gridworld_heatmap \
8+
with {config} \
9+
::: config ${CONFIGS}

src/evaluating_rewards/analysis/plot_divergence_heatmap.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def logging_config(log_root, search):
7373

7474

7575
@plot_divergence_heatmap_ex.named_config
76-
def fast():
76+
def test():
7777
"""Intended for debugging/unit test."""
7878
data_root = os.path.join("tests", "data")
7979
data_subdir = "comparison"

0 commit comments

Comments
 (0)