Skip to content

Commit a71f9db

Browse files
AdamGleave authored and araffin committed
Make VecNormalize pickleable (#525)
* Make VecNormalize pickleable
* Docstrings and load/save methods
* Test serializing VecNormalize
* Bugfix in tests
* Fix lint errors
* VecNormalize: make venv mandatory
* Update example in documentation with new VecNormalize save routine
1 parent 2de0bb6 commit a71f9db

4 files changed

Lines changed: 119 additions & 14 deletions

File tree

docs/guide/examples.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -282,10 +282,10 @@ will compute a running average and standard deviation of input features (it can
282282
model = PPO2(MlpPolicy, env)
283283
model.learn(total_timesteps=2000)
284284
285-
# Don't forget to save the running average when saving the agent
285+
# Don't forget to save the VecNormalize statistics when saving the agent
286286
log_dir = "/tmp/"
287287
model.save(log_dir + "ppo_reacher")
288-
env.save_running_average(log_dir)
288+
env.save(os.path.join(log_dir, "vec_normalize.pkl"))
289289
290290
291291
Custom Policy Network

docs/misc/changelog.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ Breaking Changes:
1717
New Features:
1818
^^^^^^^^^^^^^
1919
- Add `n_cpu_tf_sess` to model constructor to choose the number of threads used by Tensorflow
20+
- `VecNormalize` now supports being pickled and unpickled.
2021

2122
Bug Fixes:
2223
^^^^^^^^^^
@@ -28,6 +29,7 @@ Deprecations:
2829
^^^^^^^^^^^^^
2930
- `nprocs` (ACKTR) and `num_procs` (ACER) are deprecated in favor of `n_cpu_tf_sess` which is now common
3031
to all algorithms
32+
- `VecNormalize`: `load_running_average` and `save_running_average` are deprecated in favour of using pickle.
3133

3234
Others:
3335
^^^^^^^

stable_baselines/common/vec_env/vec_normalize.py

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import pickle
2+
import warnings
23

34
import numpy as np
45

@@ -9,7 +10,10 @@
910
class VecNormalize(VecEnvWrapper):
1011
"""
1112
A moving average, normalizing wrapper for vectorized environment.
12-
has support for saving/loading moving average,
13+
14+
It is pickleable; pickling saves the moving averages and configuration parameters.
15+
The wrapped environment `venv` is not saved, and must be restored manually with
16+
`set_venv` after being unpickled.
1317
1418
:param venv: (VecEnv) the vectorized environment to wrap
1519
:param training: (bool) Whether to update or not the moving average
@@ -37,6 +41,45 @@ def __init__(self, venv, training=True, norm_obs=True, norm_reward=True,
3741
self.norm_reward = norm_reward
3842
self.old_obs = np.array([])
3943

44+
def __getstate__(self):
    """
    Build the attribute dictionary that pickle stores for this wrapper.

    The wrapped ``venv`` is left out, since VecEnvs may not be pickleable
    in general.

    :return: (dict) shallow copy of the instance attributes without
        ``venv``, ``class_attributes`` and ``ret``
    """
    pickled_attrs = dict(self.__dict__)
    # 'venv' and 'class_attributes' are not pickleable; 'ret' depends on
    # them, so it is preferable not to pickle it either
    for excluded in ('venv', 'class_attributes', 'ret'):
        del pickled_attrs[excluded]
    return pickled_attrs
56+
57+
def __setstate__(self, state):
    """
    Rebuild the wrapper from a pickled attribute dictionary.

    The restored object has no environment attached (``venv`` is ``None``);
    the user must call ``set_venv()`` after unpickling before using it.

    :param state: (dict) attributes to restore, expected not to contain ``venv``
    """
    vars(self).update(state)
    # sanity check: the pickled state must not carry an environment
    assert 'venv' not in state
    self.venv = None
67+
68+
def set_venv(self, venv):
    """
    Sets the vector environment to wrap to venv.

    Also sets attributes derived from this such as `num_envs`, and resets
    `ret` to zeros.

    :param venv: (VecEnv)
    :raises ValueError: if a venv is already attached, or if the observation
        shape of `venv` does not match the stored `obs_rms` statistics
    """
    if self.venv is not None:
        raise ValueError("Trying to set venv of already initialized VecNormalize wrapper.")
    VecEnvWrapper.__init__(self, venv)
    if self.obs_rms.mean.shape != self.observation_space.shape:
        # Detach again so the wrapper is not left bound to an incompatible
        # venv; otherwise a later set_venv() with a valid environment would
        # be rejected by the "already initialized" guard above.
        self.venv = None
        raise ValueError("venv is incompatible with current statistics.")
    self.ret = np.zeros(self.num_envs)
82+
4083
def step_wait(self):
4184
"""
4285
Apply sequence of actions to sequence of environments
@@ -88,18 +131,46 @@ def reset(self):
88131
self.ret = np.zeros(self.num_envs)
89132
return self._normalize_observation(obs)
90133

134+
@staticmethod
135+
def load(load_path, venv):
136+
"""
137+
Loads a saved VecNormalize object.
138+
139+
:param load_path: the path to load from.
140+
:param venv: the VecEnv to wrap.
141+
:return: (VecNormalize)
142+
"""
143+
with open(load_path, "rb") as file_handler:
144+
vec_normalize = pickle.load(file_handler)
145+
vec_normalize.set_venv(venv)
146+
return vec_normalize
147+
148+
def save(self, save_path):
    """
    Pickle this object to the file at ``save_path``.

    :param save_path: path of the file to write; it is opened in binary
        write mode, so an existing file is overwritten
    """
    with open(save_path, "wb") as pickle_file:
        pickle.dump(self, pickle_file)
151+
91152
def save_running_average(self, path):
    """
    Pickle ``obs_rms`` and ``ret_rms`` to ``obs_rms.pkl`` and ``ret_rms.pkl``
    inside ``path``.

    :param path: (str) path to log dir

    .. deprecated:: 2.9.0
        This function will be removed in a future version
    """
    # stacklevel=2 attributes the warning to the calling code rather than to
    # this line, so users can see which of their calls is deprecated.
    warnings.warn("Usage of `save_running_average` is deprecated. Please "
                  "use `save` or pickle instead.", DeprecationWarning, stacklevel=2)
    for rms, name in zip([self.obs_rms, self.ret_rms], ['obs_rms', 'ret_rms']):
        with open("{}/{}.pkl".format(path, name), 'wb') as file_handler:
            pickle.dump(rms, file_handler)
98164

99165
def load_running_average(self, path):
    """
    Replace ``obs_rms`` and ``ret_rms`` with the objects unpickled from
    ``obs_rms.pkl`` and ``ret_rms.pkl`` inside ``path``.

    The files are read with ``pickle``, so only load from a trusted source.

    :param path: (str) path to log dir

    .. deprecated:: 2.9.0
        This function will be removed in a future version
    """
    # stacklevel=2 attributes the warning to the calling code rather than to
    # this line, so users can see which of their calls is deprecated.
    warnings.warn("Usage of `load_running_average` is deprecated. Please "
                  "use `load` or pickle instead.", DeprecationWarning, stacklevel=2)
    for name in ['obs_rms', 'ret_rms']:
        with open("{}/{}.pkl".format(path, name), 'rb') as file_handler:
            setattr(self, name, pickle.load(file_handler))

tests/test_vec_normalize.py

Lines changed: 43 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@
1111
ENV_ID = 'Pendulum-v0'
1212

1313

14+
def make_env():
    """Create a new instance of the gym environment named by ``ENV_ID``."""
    environment = gym.make(ENV_ID)
    return environment
16+
17+
1418
def test_runningmeanstd():
1519
"""Test RunningMeanStd object"""
1620
for (x_1, x_2, x_3) in [
@@ -28,20 +32,48 @@ def test_runningmeanstd():
2832
assert np.allclose(moments_1, moments_2)
2933

3034

31-
def test_vec_env():
32-
"""Test VecNormalize Object"""
35+
def check_rms_equal(rmsa, rmsb):
    """
    Assert that two running-statistics objects agree element-wise on
    ``mean``, ``var`` and ``count``.

    :param rmsa: first object exposing ``mean``, ``var`` and ``count``
    :param rmsb: second object exposing ``mean``, ``var`` and ``count``
    """
    for field in ('mean', 'var', 'count'):
        assert np.all(getattr(rmsa, field) == getattr(rmsb, field))
39+
40+
41+
def check_vec_norm_equal(norma, normb):
42+
assert norma.observation_space == normb.observation_space
43+
assert norma.action_space == normb.action_space
44+
assert norma.num_envs == normb.num_envs
3345

34-
def make_env():
35-
return gym.make(ENV_ID)
46+
check_rms_equal(norma.obs_rms, normb.obs_rms)
47+
check_rms_equal(norma.ret_rms, normb.ret_rms)
48+
assert norma.clip_obs == normb.clip_obs
49+
assert norma.clip_reward == normb.clip_reward
50+
assert norma.norm_obs == normb.norm_obs
51+
assert norma.norm_reward == normb.norm_reward
3652

37-
env = DummyVecEnv([make_env])
38-
env = VecNormalize(env, norm_obs=True, norm_reward=True, clip_obs=10., clip_reward=10.)
39-
_, done = env.reset(), [False]
40-
obs = None
53+
assert np.all(norma.ret == normb.ret)
54+
assert norma.gamma == normb.gamma
55+
assert norma.epsilon == normb.epsilon
56+
assert norma.training == normb.training
57+
58+
59+
def test_vec_env(tmpdir):
60+
"""Test VecNormalize Object"""
61+
clip_obs = 0.5
62+
clip_reward = 5.0
63+
64+
orig_venv = DummyVecEnv([make_env])
65+
norm_venv = VecNormalize(orig_venv, norm_obs=True, norm_reward=True, clip_obs=clip_obs, clip_reward=clip_reward)
66+
_, done = norm_venv.reset(), [False]
4167
while not done[0]:
42-
actions = [env.action_space.sample()]
43-
obs, _, done, _ = env.step(actions)
44-
assert np.max(obs) <= 10
68+
actions = [norm_venv.action_space.sample()]
69+
obs, rew, done, _ = norm_venv.step(actions)
70+
assert np.max(np.abs(obs)) <= clip_obs
71+
assert np.max(np.abs(rew)) <= clip_reward
72+
73+
path = str(tmpdir.join("vec_normalize"))
74+
norm_venv.save(path)
75+
deserialized = VecNormalize.load(path, venv=orig_venv)
76+
check_vec_norm_equal(norm_venv, deserialized)
4577

4678

4779
def test_mpi_runningmeanstd():

0 commit comments

Comments
 (0)