Skip to content

Commit 24d90b2

Browse files
authored
Improve PointMass Reward Heatmap Script (#9)
* Change labels in divergence heatmaps * Have plot_pm_reward use stylesheets too * Fix MPL import and backend * Change figure label positions, tweak parameters * Fix docs * Bugfix: do not delete TEXINPUTS if never existed * Avoid unnecessary TeX * Fix TeX and os.environ for reals
1 parent 1fdf9f2 commit 24d90b2

6 files changed

Lines changed: 172 additions & 76 deletions

File tree

src/evaluating_rewards/analysis/latex/figemojis.sty renamed to src/evaluating_rewards/analysis/latex/figsymbols.sty

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,9 @@
1010
\newcommand{\backward}[1]{\reflectbox{#1}}
1111

1212
\newcommand{\controlpenalty}{\reflectbox{\emoji{noto_snail}}}
13-
\newcommand{\nocontrolpenalty}{\reflectbox{\emoji{mozilla_cheetah}}}
13+
\newcommand{\nocontrolpenalty}{\reflectbox{\emoji{mozilla_cheetah}}}
14+
15+
\newcommand{\sparse}{\texttt{S}}
16+
\newcommand{\dense}{\texttt{D}}
17+
\newcommand{\magnitude}{\texttt{M}}
18+
\newcommand{\zeroreward}{\texttt{Zero}}

src/evaluating_rewards/analysis/plot_divergence_heatmap.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from typing import Any, Iterable, Mapping, Optional
2020

2121
from imitation import util
22+
import matplotlib.pyplot as plt
2223
import sacred
2324

2425
from evaluating_rewards import serialize
@@ -30,8 +31,6 @@
3031

3132
def horizontal_ticks() -> None:
3233
# lazy import to allow custom backend
33-
import matplotlib.pyplot as plt # pylint:disable=import-outside-toplevel
34-
3534
plt.xticks(rotation="horizontal")
3635
plt.yticks(rotation="horizontal")
3736

@@ -87,6 +86,18 @@ def fast():
8786
del _
8887

8988

89+
@plot_divergence_heatmap_ex.named_config
90+
def dataset_transition():
91+
"""Searches for comparisons using `random_transition_generator`."""
92+
search = { # noqa: F841 pylint:disable=unused-variable
93+
"dataset_factory": {
94+
"escape/py/function": (
95+
"evaluating_rewards.experiments.datasets.random_transition_generator"
96+
),
97+
},
98+
}
99+
100+
90101
def _norm(args: Iterable[str]) -> bool:
91102
return any(visualize.match("evaluating_rewards/PointMassGroundTruth-v0")(args))
92103

@@ -97,7 +108,6 @@ def point_mass():
97108
search = { # noqa: F841 pylint:disable=unused-variable
98109
"env_name": "evaluating_rewards/PointMassLine-v0",
99110
"dataset_factory": {
100-
# can also use evaluating_rewards.experiments.datasets.random_transition_generator
101111
"escape/py/function": "evaluating_rewards.experiments.datasets.random_policy_generator",
102112
},
103113
}
@@ -109,7 +119,7 @@ def point_mass():
109119
"norm": [visualize.zero, visualize.same, _norm],
110120
"all": [visualize.always_true],
111121
}
112-
order = ["SparseNoCtrl", "Sparse", "DenseNoCtrl", "Dense", "GroundTruth"]
122+
order = ["SparseNoCtrl", "SparseWithCtrl", "DenseNoCtrl", "DenseWithCtrl", "GroundTruth"]
113123
heatmap_kwargs["order"] = [f"evaluating_rewards/PointMass{label}-v0" for label in order]
114124
heatmap_kwargs["after_plot"] = horizontal_ticks
115125
del order
@@ -210,16 +220,7 @@ def plot_divergence_heatmap(
210220
log_dir: directory to write figures and other logging to.
211221
save_kwargs: passed through to `analysis.save_figs`.
212222
"""
213-
if "tex" in styles:
214-
import matplotlib # pylint:disable=import-outside-toplevel
215-
216-
matplotlib.use("pgf") # PGF backend best for LaTeX
217-
os.environ["TEXINPUTS"] = stylesheets.LATEX_DIR + ":"
218-
styles = [stylesheets.STYLES[style] for style in styles]
219-
220-
import matplotlib.pyplot as plt # pylint:disable=import-outside-toplevel
221-
222-
with plt.style.context(styles):
223+
with stylesheets.setup_styles(styles):
223224
data_dir = data_root
224225
if data_subdir is not None:
225226
data_dir = os.path.join(data_dir, data_subdir)

src/evaluating_rewards/analysis/plot_pm_reward.py

Lines changed: 48 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
"""
1919

2020
import os
21-
from typing import Any, Mapping, Sequence, Tuple
21+
from typing import Any, Iterable, Mapping, Sequence, Tuple
2222

2323
import gym
2424
from imitation import util
@@ -28,7 +28,7 @@
2828
import xarray as xr
2929

3030
from evaluating_rewards import serialize
31-
from evaluating_rewards.analysis import visualize
31+
from evaluating_rewards.analysis import stylesheets, visualize
3232
from evaluating_rewards.experiments import point_mass_analysis
3333
from evaluating_rewards.scripts import script_utils
3434

@@ -55,10 +55,9 @@ def default_config():
5555
act_lim = lim # action point range
5656

5757
# Figure parameters
58+
styles = ["paper", "pointmass-2col", "tex"]
5859
ncols = 3 # number of heatmaps per row
59-
width = 5 # in
60-
height = 4 # in
61-
cbar_kwargs = {"fraction": 0.1, "pad": 0.05}
60+
cbar_kwargs = {"fraction": 0.07, "pad": 0.02}
6261
fmt = "pdf" # file type
6362
_ = locals() # quieten flake8 unused variable warning
6463
del _
@@ -69,12 +68,11 @@ def default_config():
6968

7069
@plot_pm_reward_ex.config
7170
def logging_config(log_root, models, reward_type, reward_path):
71+
data_root = os.path.join(log_root, "model_comparison")
7272
if models is None:
73-
save_path = os.path.join(
73+
log_dir = os.path.join(
7474
log_root, reward_type.replace("/", "_"), reward_path.replace("/", "_")
7575
)
76-
else:
77-
save_path = util.make_unique_timestamp()
7876
_ = locals() # quieten flake8 unused variable warning
7977
del _
8078

@@ -87,7 +85,7 @@ def reward_config(models, reward_type, reward_path):
8785
del _
8886

8987

90-
STRIP_CONFIG = dict(pos_density=7, ncols=7, width=9.5, height=1.5)
88+
STRIP_CONFIG = dict(pos_density=7, ncols=7)
9189

9290

9391
@plot_pm_reward_ex.named_config
@@ -100,7 +98,6 @@ def strip():
10098
def dense_no_ctrl_sparsified():
10199
"""PointMassDenseNoCtrl along with sparsified and ground-truth sparse reward."""
102100
locals().update(**STRIP_CONFIG)
103-
height = 4.5
104101
pos_lim = 0.15
105102
# Use lists of tuples rather than OrderedDict as Sacred reorders dictionaries
106103
models = [
@@ -109,8 +106,6 @@ def dense_no_ctrl_sparsified():
109106
"Sparsified",
110107
"evaluating_rewards/RewardModel-v0",
111108
os.path.join(
112-
serialize.get_output_dir(),
113-
"model_comparison",
114109
"evaluating_rewards_PointMassLine-v0",
115110
"20190921_190606_58935eb0a51849508381daf1055d0360",
116111
"model",
@@ -125,13 +120,18 @@ def dense_no_ctrl_sparsified():
125120
@plot_pm_reward_ex.named_config
126121
def fast():
127122
"""Small config, intended for tests / debugging."""
128-
density = 5 # noqa: F841 pylint:disable=unused-variable
123+
density = 5
124+
styles = ["paper", "pointmass-2col"] # don't use TeX for tests
125+
_ = locals()
126+
del _
129127

130128

131129
@plot_pm_reward_ex.main
132130
def plot_pm_reward(
131+
styles: Iterable[str],
133132
env_name: str,
134133
models: Sequence[Tuple[str, str, str]],
134+
data_root: str,
135135
# Mesh parameters
136136
pos_lim: float,
137137
pos_density: int,
@@ -140,46 +140,45 @@ def plot_pm_reward(
140140
density: int,
141141
# Figure parameters
142142
ncols: int,
143-
width: float,
144-
height: float,
145143
cbar_kwargs: Mapping[str, Any],
146-
save_path: str,
144+
log_dir: str,
147145
fmt: str,
148146
) -> xr.DataArray:
149147
"""Entry-point into script to visualize a reward model for point mass."""
150-
env = gym.make(env_name)
151-
venv = vec_env.DummyVecEnv([lambda: env])
152-
goal = np.array([0.0])
153-
154-
rewards = {}
155-
with util.make_session():
156-
for model_name, reward_type, reward_path in models:
157-
model = serialize.load_reward(reward_type, reward_path, venv)
158-
reward = point_mass_analysis.evaluate_reward_model(
159-
env,
160-
model,
161-
goal=goal,
162-
pos_lim=pos_lim,
163-
pos_density=pos_density,
164-
vel_lim=vel_lim,
165-
act_lim=act_lim,
166-
density=density,
167-
)
168-
rewards[model_name] = reward
169-
170-
if len(rewards) == 1:
171-
reward = next(iter(rewards.values()))
172-
kwargs = {"col_wrap": ncols}
173-
else:
174-
reward = xr.Dataset(rewards).to_array("model")
175-
kwargs = {"row": "Model"}
176-
177-
fig = point_mass_analysis.plot_reward(
178-
reward, figsize=(width, height), cbar_kwargs=cbar_kwargs, **kwargs
179-
)
180-
visualize.save_fig(save_path, fig, fmt=fmt)
181-
182-
return reward
148+
with stylesheets.setup_styles(styles):
149+
env = gym.make(env_name)
150+
venv = vec_env.DummyVecEnv([lambda: env])
151+
goal = np.array([0.0])
152+
153+
rewards = {}
154+
with util.make_session():
155+
for model_name, reward_type, reward_path in models:
156+
reward_path = os.path.join(data_root, reward_path)
157+
model = serialize.load_reward(reward_type, reward_path, venv)
158+
reward = point_mass_analysis.evaluate_reward_model(
159+
env,
160+
model,
161+
goal=goal,
162+
pos_lim=pos_lim,
163+
pos_density=pos_density,
164+
vel_lim=vel_lim,
165+
act_lim=act_lim,
166+
density=density,
167+
)
168+
rewards[model_name] = reward
169+
170+
if len(rewards) == 1:
171+
reward = next(iter(rewards.values()))
172+
kwargs = {"col_wrap": ncols}
173+
else:
174+
reward = xr.Dataset(rewards).to_array("model")
175+
kwargs = {"row": "Model"}
176+
177+
fig = point_mass_analysis.plot_reward(reward, cbar_kwargs=cbar_kwargs, **kwargs)
178+
save_path = os.path.join(log_dir, "reward")
179+
visualize.save_fig(save_path, fig, fmt=fmt)
180+
181+
return reward
183182

184183

185184
if __name__ == "__main__":

src/evaluating_rewards/analysis/stylesheets.py

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
"""matplotlib styles."""
22

3+
import contextlib
34
import os
5+
from typing import Iterable, Iterator
46

57
LATEX_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "latex")
68

@@ -16,6 +18,15 @@
1618
"xtick.labelsize": 10,
1719
"ytick.labelsize": 10,
1820
},
21+
"pointmass-2col": {
22+
"figure.figsize": (6.75, 2.5),
23+
"figure.subplot.left": 0.2,
24+
"figure.subplot.right": 1.0,
25+
"figure.subplot.top": 0.92,
26+
"figure.subplot.bottom": 0.16,
27+
"figure.subplot.hspace": 0.2,
28+
"figure.subplot.wspace": 0.25,
29+
},
1930
"heatmap-2col": {"figure.figsize": (6.75, 5.0625)},
2031
"heatmap-1col": {
2132
"font.size": 8,
@@ -24,14 +35,49 @@
2435
"figure.figsize": (3.25, 2.4375),
2536
"figure.subplot.top": 0.99,
2637
"figure.subplot.bottom": 0.16,
27-
"figure.subplot.left": 0.15,
38+
"figure.subplot.left": 0.16,
2839
"figure.subplot.right": 0.91,
2940
},
3041
"tex": {
31-
"backend": "pgf",
3242
"text.usetex": True,
3343
"pgf.texsystem": "pdflatex",
3444
"pgf.rcfonts": False,
35-
"pgf.preamble": [r"\usepackage{figemojis}", r"\usepackage{times}"],
45+
"pgf.preamble": [r"\usepackage{figsymbols}", r"\usepackage{times}"],
3646
},
3747
}
48+
49+
50+
@contextlib.contextmanager
51+
def setup_styles(styles: Iterable[str]) -> Iterator[None]:
52+
"""Context manager: uses specified matplotlib styles while in context.
53+
54+
Side-effect: if "tex" is in styles, will switch `matplotlib` backend to `pgf`.
55+
56+
Args:
57+
styles: keys of styles defined in `STYLES`.
58+
59+
Returns:
60+
A ContextManager. While entered in the context, the specified styles are applied,
61+
and (if "tex" is one of the styles) the environment variable "TEXINPUTS" is set
62+
to support custom macros."""
63+
old_tex_inputs = os.environ.get("TEXINPUTS")
64+
try:
65+
if "tex" in styles:
66+
import matplotlib # pylint:disable=import-outside-toplevel
67+
68+
# PGF backend best for LaTeX. matplotlib probably already imported:
69+
# but should be able to switch as non-interactive.
70+
matplotlib.use("pgf", warn=False, force=True)
71+
os.environ["TEXINPUTS"] = LATEX_DIR + ":"
72+
styles = [STYLES[style] for style in styles]
73+
74+
import matplotlib.pyplot as plt # pylint:disable=import-outside-toplevel
75+
76+
with plt.style.context(styles):
77+
yield
78+
finally:
79+
if "tex" in styles:
80+
if old_tex_inputs is None:
81+
del os.environ["TEXINPUTS"]
82+
else:
83+
os.environ["TEXINPUTS"] = old_tex_inputs

src/evaluating_rewards/analysis/visualize.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,10 @@
3131
TRANSFORMATIONS = {
3232
r"^evaluating_rewards[_/](.*)-v0": r"\1",
3333
r"^imitation[_/](.*)-v0": r"\1",
34-
"^Zero-v0": "Zero",
35-
"^PointMassDense": "Dense",
36-
"^PointMassDenseNoCtrl": "Dense\nNo Ctrl",
37-
"^PointMassGroundTruth": "Norm",
38-
"^PointMassSparse": "Sparse",
39-
"^PointMassSparseNoCtrl": "Sparse\nNo Ctrl",
34+
"^Zero": r"\\zeroreward{}",
35+
"^PointMassDense": r"\\dense{}",
36+
"^PointMassGroundTruth": r"\\magnitude{}\\controlpenalty{}",
37+
"^PointMassSparse": r"\\sparse{}",
4038
"^PointMazeGroundTruth": "GT",
4139
r"(.*)(Hopper|HalfCheetah)GroundTruth(.*)": r"\1\2\\running{}\3",
4240
r"(.*)(Hopper|HalfCheetah)Backflip(.*)": r"\1\2\\backflipping{}\3",

0 commit comments

Comments
 (0)