|
| 1 | +# Copyright 2020 Adam Gleave |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +"""Heatmaps for rewards in gridworld environments. |
| 16 | +
|
| 17 | +This is currently only used for illustrative examples in the paper; |
| 18 | +none of the actual experiments are gridworlds.""" |
| 19 | + |
| 20 | +import collections |
| 21 | +import enum |
| 22 | +import math |
| 23 | +from typing import Tuple |
| 24 | +from unittest import mock |
| 25 | + |
| 26 | +import matplotlib |
| 27 | +import matplotlib.collections as mcollections |
| 28 | +import matplotlib.colors as mcolors |
| 29 | +import matplotlib.pyplot as plt |
| 30 | +import mdptoolbox |
| 31 | +import numpy as np |
| 32 | +import seaborn as sns |
| 33 | + |
| 34 | + |
class Actions(enum.IntEnum):
    """Actions an agent can take in a gridworld cell.

    The integer value is used as the action index elsewhere in this module
    (e.g. the first axis of the transition matrix and the last axis of
    state-action reward arrays).
    """

    STAY = 0
    LEFT = 1
    UP = 2
    RIGHT = 3
    DOWN = 4
| 41 | + |
| 42 | + |
# (x, y) offset caused by taking an action: LEFT/RIGHT change x, DOWN/UP change y,
# and STAY leaves the position unchanged.
ACTION_DELTA = {
    Actions.STAY: (0, 0),
    Actions.LEFT: (-1, 0),
    Actions.UP: (0, 1),
    Actions.RIGHT: (1, 0),
    Actions.DOWN: (0, -1),
}
| 51 | + |
# Corners of a unit square centred at (0.5, 0.5), listed clockwise when y points
# up: bottom-left, top-left, top-right, bottom-right.
CORNERS = [(0, 0), (0, 1), (1, 1), (1, 0)]
# Vertices subdividing the unit square for each action, keyed by the action's (x, y) delta.
OFFSETS = {
    # Triangles, cutting unit-square into quarters: two adjacent corners plus the centre.
    # For action values 1..4 (LEFT, UP, RIGHT, DOWN), `value - 1` and
    # `value % len(CORNERS)` select the two corners on the side the action moves towards.
    direction: np.array(
        [CORNERS[action.value - 1], [0.5, 0.5], CORNERS[action.value % len(CORNERS)]]
    )
    for action, direction in ACTION_DELTA.items()
}
# Circle at the center: replaces the placeholder triangle the comprehension
# produced for STAY with the single centre point.
OFFSETS[(0, 0)] = np.array([0.5, 0.5])
| 64 | + |
| 65 | + |
def shape(state_reward: np.ndarray, state_potential: np.ndarray) -> np.ndarray:
    """Apply (undiscounted) potential shaping to a state-only reward.

    Args:
        state_reward: a two-dimensional array, indexed by `(i,j)`.
        state_potential: a two-dimensional array of the same shape as `state_reward`.

    Returns:
        A state-action reward `sa_reward`: a three-dimensional float32 array
        indexed by `(i,j,a)`, where `a` follows the iteration order of
        `ACTION_DELTA`. Each entry is
        `state_reward[i, j] + state_potential[i', j'] - state_potential[i, j]`,
        with `(i', j')` the cell reached from `(i, j)` by action `a`. Entries
        whose successor lies outside the grid are NaN.
    """
    assert state_reward.ndim == 2
    assert state_reward.shape == state_potential.shape

    def _nan_border(grid: np.ndarray) -> np.ndarray:
        # Surround the grid with a one-cell NaN frame so that moving off the
        # edge picks up NaN rather than wrapping around to the opposite side.
        return np.pad(
            grid.astype(np.float32), pad_width=[(1, 1), (1, 1)], constant_values=np.nan
        )

    reward_framed = _nan_border(state_reward)
    potential_framed = _nan_border(state_potential)

    per_action = []
    for dx, dy in ACTION_DELTA.values():
        # Exactly one of dx/dy is non-zero (or both are zero for STAY), so
        # their sum is the signed step along the chosen axis.
        roll_axis = 0 if dx else 1
        step = dx + dy
        successor_potential = np.roll(potential_framed, -step, axis=roll_axis)
        shaped_framed = reward_framed + successor_potential - potential_framed
        # Drop the NaN frame again.
        per_action.append(shaped_framed[1:-1, 1:-1])

    # Stack is (action, i, j); callers expect (i, j, action).
    return np.array(per_action).transpose((1, 2, 0))
| 98 | + |
| 99 | + |
| 100 | +def _make_transitions( |
| 101 | + transitions: np.ndarray, |
| 102 | + low_action: int, |
| 103 | + high_action: int, |
| 104 | + states: np.ndarray, |
| 105 | + idx: np.ndarray, |
| 106 | + n: int, |
| 107 | +) -> None: |
| 108 | + transitions[low_action, states[idx == 0], states[idx == 0]] = 1 |
| 109 | + transitions[low_action, states[idx > 0], states[idx < n - 1]] = 1 |
| 110 | + transitions[high_action, states[idx == n - 1], states[idx == n - 1]] = 1 |
| 111 | + transitions[high_action, states[idx < n - 1], states[idx > 0]] = 1 |
| 112 | + |
| 113 | + |
def build_mdp(state_action_reward: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Create transition matrix for deterministic gridworld and reshape reward.

    Args:
        state_action_reward: a three-dimensional array indexed by `(x, y, a)`.
            NaN entries mark actions that would leave the grid.

    Returns:
        A tuple `(transitions, reward)`. `transitions` has shape `(na, ns, ns)`
        with `transitions[a, s, s'] = 1` when action `a` moves state `s` to
        `s'`; moves that would leave the grid keep the state unchanged.
        `reward` has shape `(ns, na)`, with NaN entries replaced by the STAY
        reward of the same state. States are flattened in C order, i.e.
        `s = x * ylen + y`.
    """
    xlen, ylen, na = state_action_reward.shape
    ns = xlen * ylen

    transitions = np.zeros((na, ns, ns))
    transitions[Actions.STAY.value, :, :] = np.eye(ns, ns)
    states = np.arange(ns)
    # The reward below is flattened with a C-order reshape, so s = x * ylen + y.
    # Decompose state indices the same way; the x coordinate ranges over xlen
    # values and y over ylen values. (Mixing these up only happens to work for
    # square grids.)
    xs = states // ylen
    ys = states % ylen
    # LEFT/RIGHT move along x; DOWN/UP move along y (see ACTION_DELTA).
    _make_transitions(transitions, Actions.LEFT.value, Actions.RIGHT.value, states, xs, xlen)
    _make_transitions(transitions, Actions.DOWN.value, Actions.UP.value, states, ys, ylen)

    # Copy before reshaping so the caller's array is never modified.
    reward = state_action_reward.copy().reshape(ns, na)
    # We use NaN for transitions that would go outside the gridworld.
    # But in above transition dynamics these are equivalent to stay, so rewrite.
    mask = np.isnan(reward)
    stay_reward = reward[:, Actions.STAY.value]
    stay_tiled = np.tile(stay_reward, (na, 1)).T
    reward[mask] = stay_tiled[mask]
    assert np.isfinite(reward).all()

    return transitions, reward
| 138 | + |
| 139 | + |
| 140 | +def _no_op_iter(*args, **kwargs): |
| 141 | + """Does nothing, workaround for bug in pymdptoolbox GH#32.""" |
| 142 | + del args, kwargs |
| 143 | + |
| 144 | + |
def compute_qvalues(state_action_reward: np.ndarray, discount: float) -> np.ndarray:
    """Computes the Q-values of `state_action_reward` under deterministic dynamics.

    Args:
        state_action_reward: three-dimensional reward array, passed to `build_mdp`.
        discount: discount factor used for value iteration and the Q backup.

    Returns:
        An array of shape `(ns, na)` with
        `Q[s, a] = reward[s, a] + discount * sum_s' T[a, s, s'] * V[s']`.
    """
    transitions, reward = build_mdp(state_action_reward)

    # TODO(adam): remove this workaround once GH pymdptoolbox #32 merged.
    patch_target = "mdptoolbox.mdp.ValueIteration._boundIter"
    with mock.patch(patch_target, new=_no_op_iter):
        solver = mdptoolbox.mdp.ValueIteration(
            transitions=transitions, reward=reward, discount=discount
        )
    solver.run()

    state_values = solver.V
    # Broadcast V over the next-state axis and sum it out: (na, ns) -> (ns, na).
    next_state_value = (transitions * state_values).sum(2).T
    return reward + discount * next_state_value
| 157 | + |
| 158 | + |
def optimal_mask(state_action_reward: np.ndarray, discount: float = 0.99) -> np.ndarray:
    """Computes the optimal actions for each state in `state_action_reward`.

    Returns:
        A boolean array of the same shape as `state_action_reward`; an entry is
        True when that action's Q-value is (numerically) equal to the best
        Q-value in its state, so ties yield several optimal actions.
    """
    q_values = compute_qvalues(state_action_reward, discount)
    # Keep the action axis so the per-state maximum broadcasts against q_values.
    per_state_best = q_values.max(axis=1, keepdims=True)
    is_optimal = np.isclose(q_values, per_state_best)
    return is_optimal.reshape(state_action_reward.shape)
| 165 | + |
| 166 | + |
def _set_ticks(n: int, subaxis: matplotlib.axis.Axis) -> None:
    """Put minor ticks on the `n + 1` cell edges and labelled major ticks at cell centres."""
    cell_edges = np.arange(0, n + 1)
    cell_indices = np.arange(n)
    subaxis.set_ticks(cell_edges, minor=True)
    subaxis.set_ticks(cell_indices + 0.5)
    subaxis.set_ticklabels(cell_indices)
| 171 | + |
| 172 | + |
def _reward_make_fig(xlen: int, ylen: int) -> Tuple[plt.Figure, plt.Axes]:
    """Construct figure and set sensible defaults.

    Creates a single-axes figure spanning `[0, xlen] x [0, ylen]` with a black
    grid along the cell boundaries and tick labels centred on each cell.
    """
    fig, ax = plt.subplots(1, 1)

    # Axes limits
    ax.set_xlim(0, xlen)
    ax.set_ylim(0, ylen)

    # Make ticks centred in each cell
    for length, subaxis in ((xlen, ax.xaxis), (ylen, ax.yaxis)):
        _set_ticks(length, subaxis)

    # Draw grid along minor ticks, then remove those ticks so they don't protrude
    ax.grid(which="minor", color="k")
    ax.tick_params(which="minor", length=0, width=0)

    return fig, ax
| 187 | + |
| 188 | + |
def _reward_make_color_map(state_action_reward: np.ndarray) -> matplotlib.cm.ScalarMappable:
    """Build a color mapping normalised to the finite range of `state_action_reward`."""
    # NaN entries are ignored when computing the range.
    lowest = np.nanmin(state_action_reward)
    highest = np.nanmax(state_action_reward)
    norm = mcolors.Normalize(vmin=lowest, vmax=highest)
    return matplotlib.cm.ScalarMappable(norm=norm)
| 194 | + |
| 195 | + |
def _reward_draw_spline(
    x: int,
    y: int,
    action: int,
    optimal: bool,
    reward: float,
    from_dest: bool,
    mappable: matplotlib.cm.ScalarMappable,
    annot_padding: float,
    ax: plt.Axes,
) -> Tuple[np.ndarray, Tuple[float, ...]]:
    """Annotate one state-action reward and compute the geometry and color of its marker.

    Args:
        x: x coordinate of the state.
        y: y coordinate of the state.
        action: action index, a key of `ACTION_DELTA`.
        optimal: whether the action is optimal; optimal rewards are annotated in bold.
        reward: the reward to display.
        from_dest: if True, place the marker in the destination cell, on the
            side facing the source cell, rather than in the source cell.
        mappable: maps `reward` to an RGBA color.
        annot_padding: distance, in data units, by which the annotation is
            moved from the cell centre along the marker's direction.
        ax: axes to annotate.

    Returns:
        A tuple `(vert, color)`: the marker's vertices (triangle corners, or
        the centre point for the stay action) in data coordinates, and the
        RGBA color for `reward`.
    """
    # Compute shape position and color
    cell = np.array([x, y])
    direction = np.array(ACTION_DELTA[action])
    if from_dest:
        # Draw in the cell being entered, pointing back towards the source.
        cell = cell + direction
        direction = -direction
    offset_key = tuple(direction)
    vert = cell + OFFSETS[offset_key]
    color = mappable.to_rgba(reward)

    # Add annotation: start at the cell centre, nudged into the wedge for moves.
    label_xy = cell + 0.5
    if offset_key != (0, 0):
        label_xy = label_xy + annot_padding * direction

    # Pick dark text on light backgrounds and white text on dark ones.
    lum = sns.utils.relative_luminance(color)
    ax.annotate(
        f"{reward:.0f}",
        xy=label_xy,
        ha="center",
        va="center",
        color=".15" if lum > 0.408 else "w",
        fontweight="bold" if optimal else None,
    )

    return vert, color
| 230 | + |
| 231 | + |
def _reward_draw(
    state_action_reward: np.ndarray,
    fig: plt.Figure,
    ax: plt.Axes,
    mappable: matplotlib.cm.ScalarMappable,
    from_dest: bool,
    edgecolor: str = "gray",
    hatchcolor: str = "white",
) -> None:
    """Draw every finite state-action reward onto `ax`.

    Each reward is annotated with its value and drawn as a circle (stay action)
    or a triangular wedge (any other action), colored via `mappable`. Optimal
    actions are additionally overlaid with a hatch pattern.

    Args:
        state_action_reward: three-dimensional array indexed by `(x, y, action)`.
        fig: figure owning `ax`; used to convert the circle size to data units.
        ax: axes to draw on.
        mappable: maps rewards to colors.
        from_dest: forwarded to `_reward_draw_spline`.
        edgecolor: outline color of the markers.
        hatchcolor: color of the hatching drawn over optimal actions.
    """
    optimal_actions = optimal_mask(state_action_reward)

    # Convert the circle's radius (area given in points^2, 72 points per inch)
    # into data units, so annotations can be padded clear of the circle.
    circle_area_pt = 200
    radius_pt = math.sqrt(circle_area_pt / math.pi)
    radius_in = radius_pt / 72
    origin_display = ax.transData.transform([0.0, 0.0])
    radius_display = fig.dpi_scale_trans.transform([radius_in, 0])
    radius_data = ax.transData.inverted().transform(origin_display + radius_display)
    annot_padding = 0.25 + 0.5 * radius_data[0]

    # Both are keyed as [geometry][optimal] -> list, filled in iteration order.
    verts = collections.defaultdict(lambda: collections.defaultdict(list))
    colors = collections.defaultdict(lambda: collections.defaultdict(list))

    cells = np.nditer(state_action_reward, flags=["multi_index"])
    for reward in cells:
        index = cells.multi_index
        x, y, action = index

        if not np.isfinite(reward):
            # Only moves may be undefined; staying put is always possible.
            assert action != 0
            continue

        optimal = optimal_actions[index]
        vert, color = _reward_draw_spline(
            x, y, action, optimal, reward, from_dest, mappable, annot_padding, ax
        )

        geom = "circle" if action == 0 else "triangle"
        verts[geom][optimal].append(vert)
        colors[geom][optimal].append(color)

    def _make_triangle(optimal, **kwargs):
        return mcollections.PolyCollection(
            verts=verts["triangle"][optimal], facecolors=colors["triangle"][optimal], **kwargs,
        )

    def _make_circle(optimal, **kwargs):
        centres = verts["circle"][optimal]
        return mcollections.CircleCollection(
            sizes=[circle_area_pt] * len(centres),
            facecolors=colors["circle"][optimal],
            offsets=centres,
            transOffset=ax.transData,
            **kwargs,
        )

    triangle_collections = []
    circle_collections = []
    makers = (
        (_make_triangle, triangle_collections),
        (_make_circle, circle_collections),
    )

    for optimal in (False, True):
        for make, bucket in makers:
            bucket.append(make(optimal, edgecolors=edgecolor))
            if optimal:  # draw the hatch using a different color
                bucket.append(make(optimal, edgecolors=hatchcolor, linewidth=0, hatch="xx"))

    # Triangles are added first so the circles end up on top of them.
    for collection in triangle_collections + circle_collections:
        ax.add_collection(collection)
| 306 | + |
| 307 | + |
def plot_gridworld_reward(
    state_action_reward: np.ndarray,
    from_dest: bool = False,
    cbar_format: str = "%.0f",
    cbar_fraction: float = 0.05,
) -> plt.Figure:
    """
    Plots a heatmap of reward for the gridworld.

    Args:
      - state_action_reward: a three-dimensional array specifying the gridworld reward,
        indexed by `(x, y, action)`; the last axis must have one entry per action.
      - from_dest: if True, the triangular wedges represent reward when arriving into this
        cell from the adjacent cell; if False, represent reward when leaving this cell into
        the adjacent cell.
      - cbar_format: passed to the colorbar as its `format`.
      - cbar_fraction: the fraction of the axes the colorbar takes up.

    Returns:
      A heatmap consisting of a "supercell" for each state `(i,j)` in the original gridworld.
      This supercell contains a central circle, representing the no-op action reward and four
      triangular wedges, representing the left, up, right and down action rewards.
    """
    n_cols, n_rows, num_actions = state_action_reward.shape
    assert num_actions == len(ACTION_DELTA)

    figure, axes = _reward_make_fig(n_cols, n_rows)
    color_mappable = _reward_make_color_map(state_action_reward)
    _reward_draw(state_action_reward, figure, axes, color_mappable, from_dest)
    figure.colorbar(color_mappable, format=cbar_format, fraction=cbar_fraction)
    return figure
0 commit comments