Skip to content

Commit 89dcc75

Browse files
authored
Gridworld distance: restrict to physically realistic transitions (#12)
* Divergence heatmap: label axes with reward type
* Gridworld divergence: restrict to physically realistic transitions, for consistency with PointMass
* Hardcode axis labels so script works without TeX
1 parent 4c3db3a commit 89dcc75

5 files changed

Lines changed: 55 additions & 12 deletions

File tree

src/evaluating_rewards/analysis/gridworld_heatmap.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -114,11 +114,9 @@ def _make_transitions(
114114
transitions[high_action, states[idx < n - 1], states[idx > 0]] = 1
115115

116116

117-
def build_mdp(state_action_reward: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
118-
"""Create transition matrix for deterministic gridworld and reshape reward."""
119-
xlen, ylen, na = state_action_reward.shape
117+
def build_transitions(xlen: int, ylen: int, na: int) -> np.ndarray:
118+
"""Create transition matrix for deterministic gridworld."""
120119
ns = xlen * ylen
121-
122120
transitions = np.zeros((na, ns, ns))
123121
transitions[Actions.STAY.value, :, :] = np.eye(ns, ns)
124122
states = np.arange(ns)
@@ -127,8 +125,16 @@ def build_mdp(state_action_reward: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
127125
_make_transitions(transitions, Actions.LEFT.value, Actions.RIGHT.value, states, ys, ylen)
128126
_make_transitions(transitions, Actions.DOWN.value, Actions.UP.value, states, xs, xlen)
129127

128+
return transitions
129+
130+
131+
def build_reward(state_action_reward: np.ndarray) -> np.ndarray:
132+
"""Reshape reward and fill in NaNs."""
133+
xlen, ylen, na = state_action_reward.shape
134+
ns = xlen * ylen
130135
reward = state_action_reward.copy()
131136
reward = reward.reshape(ns, na)
137+
132138
# We use NaN for transitions that would go outside the gridworld.
133139
# Under the dynamics from `build_transitions` these moves are equivalent to stay, so overwrite them with the matching `stay_tiled` entries.
134140
mask = np.isnan(reward)
@@ -137,7 +143,7 @@ def build_mdp(state_action_reward: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
137143
reward[mask] = stay_tiled[mask]
138144
assert np.isfinite(reward).all()
139145

140-
return transitions, reward
146+
return reward
141147

142148

143149
def _no_op_iter(*args, **kwargs):
@@ -147,7 +153,8 @@ def _no_op_iter(*args, **kwargs):
147153

148154
def compute_qvalues(state_action_reward: np.ndarray, discount: float) -> np.ndarray:
149155
"""Computes the Q-values of `state_action_reward` under deterministic dynamics."""
150-
transitions, reward = build_mdp(state_action_reward)
156+
transitions = build_transitions(*state_action_reward.shape)
157+
reward = build_reward(state_action_reward)
151158

152159
# TODO(adam): remove this workaround once GH pymdptoolbox #32 merged.
153160
with mock.patch("mdptoolbox.mdp.ValueIteration._boundIter", new=_no_op_iter):

src/evaluating_rewards/analysis/latex/figsymbols.sty

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,4 +22,7 @@
2222
\newcommand{\sparsepenalty}{\texttt{Penalty}}
2323
\newcommand{\centergoal}{\texttt{Center}}
2424
\newcommand{\dirtpath}{\texttt{Path}}
25-
\newcommand{\cliffwalk}{\texttt{Cliff}}
25+
\newcommand{\cliffwalk}{\texttt{Cliff}}
26+
27+
\newcommand{\srcreward}{R_S}
28+
\newcommand{\targetreward}{R_T}

src/evaluating_rewards/analysis/plot_gridworld_divergence.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
import sacred
2626

2727
from evaluating_rewards import serialize, tabular
28-
from evaluating_rewards.analysis import gridworld_rewards, stylesheets, visualize
28+
from evaluating_rewards.analysis import gridworld_heatmap, gridworld_rewards, stylesheets, visualize
2929
from evaluating_rewards.scripts import script_utils
3030

3131
plot_gridworld_divergence_ex = sacred.Experiment("plot_gridworld_divergence")
@@ -119,6 +119,32 @@ def make_reward(cfg: Dict[str, np.ndarray], discount: float) -> np.ndarray:
119119
return tabular.shape(state_reward, potential, discount)
120120

121121

122+
def direct_divergence(source: np.ndarray, target: np.ndarray, xlen: int, ylen: int) -> float:
    """Direct divergence between two tabular rewards over realistic transitions only.

    Both rewards are multiplied elementwise by the deterministic gridworld
    transition tensor before being compared, so entries for (state, action,
    next state) triples that the dynamics cannot produce are zeroed out in
    both and do not contribute to the result.

    Args:
        source: the source reward, a 3-D array unpacked as
            (state, action, next state).
        target: the target reward, combined elementwise with the same mask
            as `source`.
        xlen: number of grid cells along the x-axis.
        ylen: number of grid cells along the y-axis; `xlen * ylen` must equal
            the number of states in `source`.

    Returns:
        Direct divergence of `source` to `target`, under squared-error metric and
        uniform random transition dataset: state s and action a are sampled
        uniformly at random and the next state s' follows deterministically.
        (Physically unattainable transitions could be included, but that would be
        inconsistent with the PointMass experiments, and is not possible in most
        environments.)
    """
    n_states, n_actions, n_next_states = source.shape
    assert n_states == xlen * ylen
    assert n_states == n_next_states

    # `build_transitions` is indexed (action, state, next state); move the
    # state axis to the front so the mask lines up with the reward arrays.
    per_action = gridworld_heatmap.build_transitions(xlen, ylen, n_actions)
    feasible = per_action.transpose(1, 0, 2)

    # Zero-out any physically unrealistic rewards in both functions.
    masked_source = source * feasible
    masked_target = target * feasible
    return tabular.direct_sq_divergence(masked_source, masked_target)
146+
147+
122148
def compute_divergence(reward_cfg: Dict[str, Any], discount: float) -> pd.Series:
123149
"""Compute divergence for each pair of rewards in `reward_cfg`."""
124150
rewards = {name: make_reward(cfg, discount) for name, cfg in reward_cfg.items()}
@@ -130,7 +156,8 @@ def compute_divergence(reward_cfg: Dict[str, Any], discount: float) -> pd.Series
130156
closest_reward = tabular.closest_reward_em(
131157
src_reward, target_reward, n_iter=1000, discount=discount
132158
)
133-
div = tabular.direct_sq_divergence(closest_reward, target_reward)
159+
xlen, ylen = reward_cfg[src_name]["state_reward"].shape
160+
div = direct_divergence(closest_reward, target_reward, xlen, ylen)
134161
divergence[target_name][src_name] = div
135162
divergence = pd.DataFrame(divergence)
136163
divergence = divergence.stack()

src/evaluating_rewards/analysis/stylesheets.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
"figure.figsize": (3.25, 2.4375),
4141
"figure.subplot.top": 0.99,
4242
"figure.subplot.bottom": 0.16,
43-
"figure.subplot.left": 0.16,
43+
"figure.subplot.left": 0.17,
4444
"figure.subplot.right": 0.91,
4545
},
4646
"gridworld-heatmap": {

src/evaluating_rewards/analysis/visualize.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,7 @@ def _heatmap_reformat(series, preserve_order):
181181

182182
def comparison_heatmap(
183183
vals: pd.Series,
184+
ax: plt.Axes,
184185
log: bool = True,
185186
fmt: Callable[[float], str] = short_e,
186187
cbar_kws: Optional[Dict[str, Any]] = None,
@@ -224,7 +225,12 @@ def comparison_heatmap(
224225
if robust:
225226
flat = data.values.flatten()
226227
kwargs["vmin"], kwargs["vmax"] = np.quantile(flat, [0.25, 0.75])
227-
sns.heatmap(data, annot=annot, fmt="s", cmap=cmap, cbar_kws=cbar_kws, mask=mask, **kwargs)
228+
sns.heatmap(
229+
data, annot=annot, fmt="s", cmap=cmap, cbar_kws=cbar_kws, mask=mask, ax=ax, **kwargs
230+
)
231+
232+
ax.set_xlabel(r"Target $R_T$")
233+
ax.set_ylabel(r"Source $R_S$")
228234

229235

230236
def median_seeds(series: pd.Series) -> pd.Series:
@@ -388,7 +394,7 @@ def compact_heatmaps(
388394
for name, matching in masks.items():
389395
fig, ax = plt.subplots(1, 1, squeeze=True)
390396
match_mask = compute_mask(loss, matching)
391-
comparison_heatmap(loss, fmt=fmt, preserve_order=True, mask=match_mask, ax=ax, **kwargs)
397+
comparison_heatmap(loss, ax=ax, fmt=fmt, preserve_order=True, mask=match_mask, **kwargs)
392398
# make room for multi-line xlabels
393399
after_plot()
394400
figs[name] = fig

0 commit comments

Comments
 (0)