1717This is currently only used for illustrative examples in the paper;
1818none of the actual experiments are gridworlds."""
1919
20- import collections
2120import enum
21+ import functools
2222import math
23- from typing import Tuple
23+ from typing import Optional , Tuple
2424from unittest import mock
2525
2626import matplotlib
27- import matplotlib .collections as mcollections
2827import matplotlib .colors as mcolors
28+ import matplotlib .patches as mpatches
2929import matplotlib .pyplot as plt
3030import mdptoolbox
3131import numpy as np
@@ -63,12 +63,15 @@ class Actions(enum.IntEnum):
6363OFFSETS [(0 , 0 )] = np .array ([0.5 , 0.5 ])
6464
6565
66- def shape (state_reward : np .ndarray , state_potential : np .ndarray ) -> np .ndarray :
66+ def shape (
67+ state_reward : np .ndarray , state_potential : np .ndarray , discount : float = 0.99
68+ ) -> np .ndarray :
6769 """Shape `state_reward` with `state_potential`.
6870
6971 Args:
7072 state_reward: a two-dimensional array, indexed by `(i,j)`.
7173 state_potential: a two-dimensional array of the same shape as `state_reward`.
74+ discount: discount rate of MDP.
7275
7376 Returns:
7477 A state-action reward `sa_reward`. This is three-dimensional array,
@@ -90,7 +93,7 @@ def shape(state_reward: np.ndarray, state_potential: np.ndarray) -> np.ndarray:
9093 for x_delta , y_delta in ACTION_DELTA .values ():
9194 axis = 0 if x_delta else 1
9295 delta = x_delta + y_delta
93- new_potential = np .roll (padded_potential , - delta , axis = axis )
96+ new_potential = discount * np .roll (padded_potential , - delta , axis = axis )
9497 shaped = padded_reward + new_potential - padded_potential
9598 res .append (shaped [1 :- 1 , 1 :- 1 ])
9699
@@ -186,10 +189,14 @@ def _reward_make_fig(xlen: int, ylen: int) -> Tuple[plt.Figure, plt.Axes]:
186189 return fig , ax
187190
188191
189- def _reward_make_color_map (state_action_reward : np .ndarray ) -> matplotlib .cm .ScalarMappable :
190- norm = mcolors .Normalize (
191- vmin = np .nanmin (state_action_reward ), vmax = np .nanmax (state_action_reward )
192- )
192+ def _reward_make_color_map (
193+ state_action_reward : np .ndarray , vmin : Optional [float ], vmax : Optional [float ]
194+ ) -> matplotlib .cm .ScalarMappable :
195+ if vmin is None :
196+ vmin = np .nanmin (state_action_reward )
197+ if vmax is None :
198+ vmax = np .nanmin (state_action_reward )
199+ norm = mcolors .Normalize (vmin = vmin , vmax = vmax )
193200 return matplotlib .cm .ScalarMappable (norm = norm )
194201
195202
@@ -203,7 +210,7 @@ def _reward_draw_spline(
203210 mappable : matplotlib .cm .ScalarMappable ,
204211 annot_padding : float ,
205212 ax : plt .Axes ,
206- ) -> Tuple [np .ndarray , Tuple [float , ...]]:
213+ ) -> Tuple [np .ndarray , Tuple [float , ...], str ]:
207214 # Compute shape position and color
208215 pos = np .array ([x , y ])
209216 direction = np .array (ACTION_DELTA [action ])
@@ -217,6 +224,7 @@ def _reward_draw_spline(
217224 text = f"{ reward :.0f} "
218225 lum = sns .utils .relative_luminance (color )
219226 text_color = ".15" if lum > 0.408 else "w"
227+ hatch_color = ".5" if lum > 0.408 else "w"
220228 xy = pos + 0.5
221229
222230 if tuple (direction ) != (0 , 0 ):
@@ -226,19 +234,27 @@ def _reward_draw_spline(
226234 text , xy = xy , ha = "center" , va = "center" , color = text_color , fontweight = fontweight ,
227235 )
228236
229- return vert , color
237+ return vert , color , hatch_color
238+
239+
240+ def _make_triangle (vert , color , ** kwargs ):
241+ return mpatches .Polygon (xy = vert , facecolor = color , ** kwargs )
242+
243+
244+ def _make_circle (vert , color , radius , ** kwargs ):
245+ return mpatches .Circle (xy = vert , radius = radius , facecolor = color , ** kwargs )
230246
231247
232248def _reward_draw (
233249 state_action_reward : np .ndarray ,
250+ discount : float ,
234251 fig : plt .Figure ,
235252 ax : plt .Axes ,
236253 mappable : matplotlib .cm .ScalarMappable ,
237254 from_dest : bool ,
238255 edgecolor : str = "gray" ,
239- hatchcolor : str = "white" ,
240256) -> None :
241- optimal_actions = optimal_mask (state_action_reward )
257+ optimal_actions = optimal_mask (state_action_reward , discount )
242258
243259 circle_area_pt = 200
244260 circle_radius_pt = math .sqrt (circle_area_pt / math .pi )
@@ -248,8 +264,8 @@ def _reward_draw(
248264 circle_radius_data = ax .transData .inverted ().transform (corner_display + circle_radius_display )
249265 annot_padding = 0.25 + 0.5 * circle_radius_data [0 ]
250266
251- verts = collections . defaultdict ( lambda : collections . defaultdict ( list ))
252- colors = collections . defaultdict ( lambda : collections . defaultdict ( list ))
267+ triangle_patches = []
268+ circle_patches = []
253269
254270 it = np .nditer (state_action_reward , flags = ["multi_index" ])
255271 while not it .finished :
@@ -262,54 +278,35 @@ def _reward_draw(
262278 assert action != 0
263279 continue
264280
265- vert , color = _reward_draw_spline (
281+ vert , color , hatch_color = _reward_draw_spline (
266282 x , y , action , optimal , reward , from_dest , mappable , annot_padding , ax
267283 )
268284
269- geom = "circle" if action == 0 else "triangle"
270- verts [geom ][optimal ].append (vert )
271- colors [geom ][optimal ].append (color )
272-
273- circle_collections = []
274- triangle_collections = []
275-
276- def _make_triangle (optimal , ** kwargs ):
277- return mcollections .PolyCollection (
278- verts = verts ["triangle" ][optimal ], facecolors = colors ["triangle" ][optimal ], ** kwargs ,
279- )
280-
281- def _make_circle (optimal , ** kwargs ):
282- circle_offsets = verts ["circle" ][optimal ]
283- return mcollections .CircleCollection (
284- sizes = [circle_area_pt ] * len (circle_offsets ),
285- facecolors = colors ["circle" ][optimal ],
286- offsets = circle_offsets ,
287- transOffset = ax .transData ,
288- ** kwargs ,
289- )
290-
291- maker_collection_dict = {
292- _make_triangle : triangle_collections ,
293- _make_circle : circle_collections ,
294- }
295-
296- for optimal in [False , True ]:
297285 hatch = "xx" if optimal else None
298-
299- for maker_fn , cols in maker_collection_dict .items ():
300- cols .append (maker_fn (optimal , edgecolors = edgecolor ))
301- if hatch : # draw the hatch using a different color
302- cols .append (maker_fn (optimal , edgecolors = hatchcolor , linewidth = 0 , hatch = hatch ))
303-
304- for cols in triangle_collections + circle_collections :
305- ax .add_collection (cols )
286+ if action == 0 :
287+ fn = functools .partial (_make_circle , radius = circle_radius_data [0 ])
288+ else :
289+ fn = _make_triangle
290+ patches = circle_patches if action == 0 else triangle_patches
291+ if hatch : # draw the hatch using a different color
292+ patches .append (fn (vert , tuple (color ), linewidth = 1 , edgecolor = hatch_color , hatch = hatch ))
293+ patches .append (fn (vert , tuple (color ), linewidth = 1 , edgecolor = edgecolor , fill = False ))
294+ else :
295+ patches .append (fn (vert , tuple (color ), linewidth = 1 , edgecolor = edgecolor ))
296+
297+ for p in triangle_patches + circle_patches :
298+ # need to draw circles on top of triangles
299+ ax .add_patch (p )
306300
307301
308302def plot_gridworld_reward (
309303 state_action_reward : np .ndarray ,
304+ discount : float = 0.99 ,
310305 from_dest : bool = False ,
311306 cbar_format : str = "%.0f" ,
312307 cbar_fraction : float = 0.05 ,
308+ vmin : Optional [float ] = None ,
309+ vmax : Optional [float ] = None ,
313310) -> plt .Figure :
314311 """
315312 Plots a heatmap of reward for the gridworld.
@@ -330,7 +327,7 @@ def plot_gridworld_reward(
330327 xlen , ylen , num_actions = state_action_reward .shape
331328 assert num_actions == len (ACTION_DELTA )
332329 fig , ax = _reward_make_fig (xlen , ylen )
333- mappable = _reward_make_color_map (state_action_reward )
334- _reward_draw (state_action_reward , fig , ax , mappable , from_dest )
330+ mappable = _reward_make_color_map (state_action_reward , vmin , vmax )
331+ _reward_draw (state_action_reward , discount , fig , ax , mappable , from_dest )
335332 fig .colorbar (mappable , format = cbar_format , fraction = cbar_fraction )
336333 return fig
0 commit comments