|
1 | 1 | import pickle |
| 2 | +import warnings |
2 | 3 |
|
3 | 4 | import numpy as np |
4 | 5 |
|
|
9 | 10 | class VecNormalize(VecEnvWrapper): |
10 | 11 | """ |
11 | 12 | A moving average, normalizing wrapper for vectorized environment. |
12 | | - has support for saving/loading moving average, |
| 13 | +
|
| 14 | + It is pickleable which will save moving averages and configuration parameters. |
| 15 | + The wrapped environment `venv` is not saved, and must be restored manually with |
| 16 | + `set_venv` after being unpickled. |
13 | 17 |
|
14 | 18 | :param venv: (VecEnv) the vectorized environment to wrap |
15 | 19 | :param training: (bool) Whether to update or not the moving average |
@@ -37,6 +41,45 @@ def __init__(self, venv, training=True, norm_obs=True, norm_reward=True, |
37 | 41 | self.norm_reward = norm_reward |
38 | 42 | self.old_obs = np.array([]) |
39 | 43 |
|
| 44 | + def __getstate__(self): |
| 45 | + """ |
| 46 | + Gets state for pickling. |
| 47 | +
|
| 48 | + Excludes self.venv, as in general VecEnv's may not be pickleable.""" |
| 49 | + state = self.__dict__.copy() |
| 50 | + # these attributes are not pickleable |
| 51 | + del state['venv'] |
| 52 | + del state['class_attributes'] |
| 53 | + # these attributes depend on the above and so we would prefer not to pickle |
| 54 | + del state['ret'] |
| 55 | + return state |
| 56 | + |
| 57 | + def __setstate__(self, state): |
| 58 | + """ |
| 59 | + Restores pickled state. |
| 60 | +
|
| 61 | + User must call set_venv() after unpickling before using. |
| 62 | +
|
| 63 | + :param state: (dict)""" |
| 64 | + self.__dict__.update(state) |
| 65 | + assert 'venv' not in state |
| 66 | + self.venv = None |
| 67 | + |
| 68 | + def set_venv(self, venv): |
| 69 | + """ |
| 70 | + Sets the vector environment to wrap to venv. |
| 71 | +
|
| 72 | + Also sets attributes derived from this such as `num_env`. |
| 73 | +
|
| 74 | + :param venv: (VecEnv) |
| 75 | + """ |
| 76 | + if self.venv is not None: |
| 77 | + raise ValueError("Trying to set venv of already initialized VecNormalize wrapper.") |
| 78 | + VecEnvWrapper.__init__(self, venv) |
| 79 | + if self.obs_rms.mean.shape != self.observation_space.shape: |
| 80 | + raise ValueError("venv is incompatible with current statistics.") |
| 81 | + self.ret = np.zeros(self.num_envs) |
| 82 | + |
40 | 83 | def step_wait(self): |
41 | 84 | """ |
42 | 85 | Apply sequence of actions to sequence of environments |
@@ -88,18 +131,46 @@ def reset(self): |
88 | 131 | self.ret = np.zeros(self.num_envs) |
89 | 132 | return self._normalize_observation(obs) |
90 | 133 |
|
| 134 | + @staticmethod |
| 135 | + def load(load_path, venv): |
| 136 | + """ |
| 137 | + Loads a saved VecNormalize object. |
| 138 | +
|
| 139 | + :param load_path: the path to load from. |
| 140 | + :param venv: the VecEnv to wrap. |
| 141 | + :return: (VecNormalize) |
| 142 | + """ |
| 143 | + with open(load_path, "rb") as file_handler: |
| 144 | + vec_normalize = pickle.load(file_handler) |
| 145 | + vec_normalize.set_venv(venv) |
| 146 | + return vec_normalize |
| 147 | + |
| 148 | + def save(self, save_path): |
| 149 | + with open(save_path, "wb") as file_handler: |
| 150 | + pickle.dump(self, file_handler) |
| 151 | + |
91 | 152 | def save_running_average(self, path): |
92 | 153 | """ |
93 | 154 | :param path: (str) path to log dir |
| 155 | +
|
| 156 | + .. deprecated:: 2.9.0 |
| 157 | + This function will be removed in a future version |
94 | 158 | """ |
| 159 | + warnings.warn("Usage of `save_running_average` is deprecated. Please " |
| 160 | + "use `save` or pickle instead.", DeprecationWarning) |
95 | 161 | for rms, name in zip([self.obs_rms, self.ret_rms], ['obs_rms', 'ret_rms']): |
96 | 162 | with open("{}/{}.pkl".format(path, name), 'wb') as file_handler: |
97 | 163 | pickle.dump(rms, file_handler) |
98 | 164 |
|
99 | 165 | def load_running_average(self, path): |
100 | 166 | """ |
101 | 167 | :param path: (str) path to log dir |
| 168 | +
|
| 169 | + .. deprecated:: 2.9.0 |
| 170 | + This function will be removed in a future version |
102 | 171 | """ |
| 172 | + warnings.warn("Usage of `load_running_average` is deprecated. Please " |
| 173 | + "use `load` or pickle instead.", DeprecationWarning) |
103 | 174 | for name in ['obs_rms', 'ret_rms']: |
104 | 175 | with open("{}/{}.pkl".format(path, name), 'rb') as file_handler: |
105 | 176 | setattr(self, name, pickle.load(file_handler)) |
0 commit comments