From ac3181441c55a48b8435b0b33bbb25f11fb569bc Mon Sep 17 00:00:00 2001 From: roger-creus Date: Thu, 23 Nov 2023 10:06:43 -0500 Subject: [PATCH 1/5] Added PPO LSTM --- pyproject.toml | 2 +- rllte/agent/__init__.py | 1 + rllte/agent/legacy/dqn.py | 2 +- rllte/agent/legacy/ppo.py | 2 +- rllte/agent/ppo_lstm.py | 210 ++++++++++++++ rllte/common/prototype/off_policy_agent.py | 244 ---------------- rllte/common/prototype/on_policy_agent.py | 57 +++- rllte/common/type_alias.py | 11 + rllte/env/__init__.py | 59 ---- rllte/xploit/encoder/mnih_cnn_encoder.py | 2 +- rllte/xploit/policy/__init__.py | 1 + .../on_policy_shared_actor_critic_lstm.py | 270 ++++++++++++++++++ rllte/xploit/storage/__init__.py | 1 + .../storage/episodic_rollout_storage.py | 196 +++++++++++++ 14 files changed, 747 insertions(+), 311 deletions(-) create mode 100644 rllte/agent/ppo_lstm.py delete mode 100644 rllte/common/prototype/off_policy_agent.py delete mode 100644 rllte/env/__init__.py create mode 100644 rllte/xploit/policy/on_policy_shared_actor_critic_lstm.py create mode 100644 rllte/xploit/storage/episodic_rollout_storage.py diff --git a/pyproject.toml b/pyproject.toml index a9213554..17abf814 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,7 +42,7 @@ dependencies = [ "pynvml==11.5.0", "matplotlib==3.6.0", "seaborn==0.12.2", - "huggingface_hub==0.14.1" + "huggingface_hub==0.14.1", ] [project.optional-dependencies] diff --git a/rllte/agent/__init__.py b/rllte/agent/__init__.py index cd75157f..09403c61 100644 --- a/rllte/agent/__init__.py +++ b/rllte/agent/__init__.py @@ -36,3 +36,4 @@ from .legacy.sacd import SACDiscrete as SACDiscrete from .legacy.td3 import TD3 as TD3 from .ppg import PPG as PPG +from .ppo_lstm import PPO_LSTM as PPO_LSTM \ No newline at end of file diff --git a/rllte/agent/legacy/dqn.py b/rllte/agent/legacy/dqn.py index 6326579e..73131d02 100644 --- a/rllte/agent/legacy/dqn.py +++ b/rllte/agent/legacy/dqn.py @@ -190,4 +190,4 @@ def update(self) -> None: # record metrics self.logger.record("train/q_loss", huber_loss.item()) self.logger.record("train/q", q_values.mean().item()) - self.logger.record("train/target_q", target_q_values.mean().item()) + self.logger.record("train/target_q", target_q_values.mean().item()) \ No newline at end of file diff --git a/rllte/agent/legacy/ppo.py b/rllte/agent/legacy/ppo.py index 63a1afc5..2e108167 100644 --- a/rllte/agent/legacy/ppo.py +++ b/rllte/agent/legacy/ppo.py @@ -88,7 +88,7 @@ def __init__( vf_coef: float = 0.5, ent_coef: float = 0.01, max_grad_norm: float = 0.5, - discount: float = 0.999, + discount: float = 0.99, init_fn: str = "orthogonal", ) -> None: super().__init__( diff --git a/rllte/agent/ppo_lstm.py b/rllte/agent/ppo_lstm.py new file mode 100644 index 00000000..074c154f --- /dev/null +++ b/rllte/agent/ppo_lstm.py @@ -0,0 +1,210 @@ +# ============================================================================= +# MIT License + +# Copyright (c) 2023 Reinforcement Learning Evolution Foundation + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial 
portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# =============================================================================
+
+
+from typing import Optional
+
+import numpy as np
+import torch as th
+from torch import nn
+
+from rllte.common.prototype import OnPolicyAgent
+from rllte.common.type_alias import VecEnv
+from rllte.xploit.encoder import IdentityEncoder, MnihCnnEncoder, EspeholtResidualEncoder, PathakCnnEncoder
+from rllte.xploit.policy import OnPolicySharedActorCriticLSTM
+from rllte.xploit.storage import EpisodicRolloutStorage
+from rllte.xplore.distribution import Bernoulli, Categorical, DiagonalGaussian, MultiCategorical
+
+
+class PPO_LSTM(OnPolicyAgent):
+    """Proximal Policy Optimization (PPO) agent with an LSTM-based recurrent policy.
+    Based on: https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppo_atari_lstm.py
+
+    Args:
+        env (VecEnv): Vectorized environments for training.
+        eval_env (Optional[VecEnv]): Vectorized environments for evaluation.
+        tag (str): An experiment tag.
+        seed (int): Random seed for reproduction.
+        device (str): Device (cpu, cuda, ...) on which the code should be run.
+        pretraining (bool): Turn on the pre-training mode.
+        num_steps (int): The sample length of each rollout.
+        feature_dim (int): Number of features extracted by the encoder.
+        batch_size (int): Number of samples per batch to load (unused here; the storage
+            derives its batch size from `num_batches`).
+        lr (float): The learning rate.
+        eps (float): Term added to the denominator to improve numerical stability.
+        hidden_dim (int): The size of the hidden layers.
+        clip_range (float): Clipping parameter.
+        clip_range_vf (Optional[float]): Clipping parameter for the value function.
+        n_epochs (int): Number of policy-update epochs per rollout.
+        vf_coef (float): Weighting coefficient of value loss.
+        ent_coef (float): Weighting coefficient of entropy bonus.
+        max_grad_norm (float): Maximum norm of gradients.
+        discount (float): Discount factor.
+        init_fn (str): Parameters initialization method.
+        num_batches (int): Number of mini-batches; each one carries the whole trajectories
+            of a subset of environments.
+
+    Returns:
+        PPO_LSTM agent instance.
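+
+    Examples:
+        A minimal usage sketch, mirroring `examples/train_ppoLstm_atari.ipynb`:
+
+        >>> from rllte.agent import PPO_LSTM
+        >>> from rllte.env import make_envpool_atari_env
+        >>> env = make_envpool_atari_env(env_id="SpaceInvaders-v5", device="cuda:0", num_envs=16, asynchronous=False)
+        >>> agent = PPO_LSTM(env=env, device="cuda:0", tag="ppo_lstm_atari")
+        >>> agent.train(num_train_steps=10_000_000)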
+ """ + + def __init__( + self, + env: VecEnv, + eval_env: Optional[VecEnv] = None, + tag: str = "default", + seed: int = 1, + device: str = "cpu", + pretraining: bool = False, + num_steps: int = 128, + feature_dim: int = 512, + batch_size: int = 256, + lr: float = 2.5e-4, + eps: float = 1e-5, + hidden_dim: int = 512, + clip_range: float = 0.1, + clip_range_vf: Optional[float] = 0.1, + n_epochs: int = 4, + vf_coef: float = 0.5, + ent_coef: float = 0.01, + max_grad_norm: float = 0.5, + discount: float = 0.999, + init_fn: str = "orthogonal", + num_batches: int = 4, + ) -> None: + super().__init__( + env=env, + eval_env=eval_env, + tag=tag, + seed=seed, + device=device, + pretraining=pretraining, + num_steps=num_steps, + use_lstm=True, + ) + + # hyper parameters + self.lr = lr + self.eps = eps + self.n_epochs = n_epochs + self.clip_range = clip_range + self.clip_range_vf = clip_range_vf + self.vf_coef = vf_coef + self.ent_coef = ent_coef + self.max_grad_norm = max_grad_norm + + # default encoder + if len(self.obs_shape) == 3: + encoder = MnihCnnEncoder(observation_space=env.observation_space, feature_dim=feature_dim) + elif len(self.obs_shape) == 1: + feature_dim = self.obs_shape[0] # type: ignore + encoder = IdentityEncoder( + observation_space=env.observation_space, feature_dim=feature_dim # type: ignore[assignment] + ) + + # default distribution + if self.action_type == "Discrete": + dist = Categorical() + elif self.action_type == "Box": + dist = DiagonalGaussian() # type: ignore[assignment] + elif self.action_type == "MultiBinary": + dist = Bernoulli() # type: ignore[assignment] + elif self.action_type == "MultiDiscrete": + dist = MultiCategorical() # type: ignore[assignment] + else: + raise NotImplementedError(f"Unsupported action type {self.action_type}!") + + # create policy + policy = OnPolicySharedActorCriticLSTM( + observation_space=env.observation_space, + action_space=env.action_space, + feature_dim=feature_dim, + hidden_dim=hidden_dim, + opt_class=th.optim.Adam, + opt_kwargs=dict(lr=lr, eps=eps), + init_fn=init_fn, + ) + + # default storage + storage = EpisodicRolloutStorage( + observation_space=env.observation_space, + action_space=env.action_space, + device=device, + storage_size=self.num_steps, + num_envs=self.num_envs, + discount=discount, + num_batches=num_batches, + ) + + # set all the modules [essential operation!!!] 
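+        # (`set` registers the encoder, policy, storage, and distribution with the
+        # agent; the policy's `freeze` later receives this encoder and distribution,
+        # see `OnPolicySharedActorCriticLSTM.freeze`.)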
+ self.set(encoder=encoder, policy=policy, storage=storage, distribution=dist) + + def update(self) -> None: + """Update function that returns training metrics such as policy loss, value loss, etc..""" + total_policy_loss = [0.0] + total_value_loss = [0.0] + total_entropy_loss = [0.0] + + for _ in range(self.n_epochs): + for batch in self.storage.sample(): + done = th.logical_or(batch.terminateds, batch.truncateds) + + # evaluate sampled actions + new_values, new_log_probs, entropy = self.policy.evaluate_actions( + obs=batch.observations, + actions=batch.actions, + lstm_state=(self.initial_lstm_state[0][:, batch.env_inds], self.initial_lstm_state[1][:, batch.env_inds]), + done=done + ) + + # policy loss part + ratio = th.exp(new_log_probs - batch.old_log_probs) + surr1 = ratio * batch.adv_targ + surr2 = th.clamp(ratio, 1.0 - self.clip_range, 1.0 + self.clip_range) * batch.adv_targ + policy_loss = -th.min(surr1, surr2).mean() + + # value loss part + if self.clip_range_vf is None: + value_loss = 0.5 * (new_values.flatten() - batch.returns).pow(2).mean() + else: + values_clipped = batch.values + (new_values.flatten() - batch.values).clamp( + -self.clip_range_vf, self.clip_range_vf + ) + values_losses = (new_values.flatten() - batch.returns).pow(2) + values_losses_clipped = (values_clipped - batch.returns).pow(2) + value_loss = 0.5 * th.max(values_losses, values_losses_clipped).mean() + + # update + self.policy.optimizers["opt"].zero_grad(set_to_none=True) + loss = value_loss * self.vf_coef + policy_loss - entropy * self.ent_coef + loss.backward() + nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm) + self.policy.optimizers["opt"].step() + + total_policy_loss.append(policy_loss.item()) + total_value_loss.append(value_loss.item()) + total_entropy_loss.append(entropy.item()) + + # record metrics + self.logger.record("train/policy_loss", np.mean(total_policy_loss)) + self.logger.record("train/value_loss", np.mean(total_value_loss)) + self.logger.record("train/entropy_loss", np.mean(total_entropy_loss)) diff --git a/rllte/common/prototype/off_policy_agent.py b/rllte/common/prototype/off_policy_agent.py deleted file mode 100644 index 06c97055..00000000 --- a/rllte/common/prototype/off_policy_agent.py +++ /dev/null @@ -1,244 +0,0 @@ -# ============================================================================= -# MIT License - -# Copyright (c) 2023 Reinforcement Learning Evolution Foundation - -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: - -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. - -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. 
-# ============================================================================= - - -from collections import deque -from copy import deepcopy -from typing import Any, Deque, Dict, List, Optional - -import numpy as np -import torch as th - -from rllte.common import utils -from rllte.common.prototype.base_agent import BaseAgent -from rllte.common.type_alias import OffPolicyType, ReplayStorageType, VecEnv - - -class OffPolicyAgent(BaseAgent): - """Trainer for off-policy algorithms. - - Args: - env (VecEnv): Vectorized environments for training. - eval_env (Optional[VecEnv]): Vectorized environments for evaluation. - tag (str): An experiment tag. - seed (int): Random seed for reproduction. - device (str): Device (cpu, cuda, ...) on which the code should be run. - pretraining (bool): Turn on pre-training model or not. - num_init_steps (int): Number of initial exploration steps. - **kwargs: Arbitrary arguments such as `batch_size` and `hidden_dim`. - - Returns: - Off-policy agent instance. - """ - - def __init__( - self, - env: VecEnv, - eval_env: Optional[VecEnv] = None, - tag: str = "default", - seed: int = 1, - device: str = "cpu", - pretraining: bool = False, - num_init_steps: int = 2000, - **kwargs, - ) -> None: - super().__init__(env=env, eval_env=eval_env, tag=tag, seed=seed, device=device, pretraining=pretraining) - self.num_init_steps = num_init_steps - # attr annotations - self.policy: OffPolicyType - self.storage: ReplayStorageType - - def update(self) -> None: - """Update the agent. Implemented by individual algorithms.""" - raise NotImplementedError - - def train( # noqa: C901 - self, - num_train_steps: int, - init_model_path: Optional[str] = None, - log_interval: int = 1, - eval_interval: int = 5000, - save_interval: int = 5000, - num_eval_episodes: int = 10, - th_compile: bool = False, - anneal_lr: bool = False - ) -> None: - """Training function. - - Args: - num_train_steps (int): The number of training steps. - init_model_path (Optional[str]): The path of the initial model. - log_interval (int): The interval of logging. - eval_interval (int): The interval of evaluation. - save_interval (int): The interval of saving model. - num_eval_episodes (int): The number of evaluation episodes. - th_compile (bool): Whether to use `th.compile` or not. - anneal_lr (bool): Whether to anneal the learning rate or not. - - Returns: - None. 
- """ - # freeze the agent and get ready for training - self.freeze(init_model_path=init_model_path, th_compile=th_compile) - - # reset the env - episode_rewards: Deque = deque(maxlen=10) - episode_steps: Deque = deque(maxlen=10) - obs, infos = self.env.reset(seed=self.seed) - - # training loop - while self.global_step < num_train_steps: - # try to eval - if (self.global_step % eval_interval) == 0 and (self.eval_env is not None): - eval_metrics = self.eval(num_eval_episodes) - - # log to console - self.logger.eval(msg=eval_metrics) - - # sample actions - with th.no_grad(), utils.eval_mode(self): - # Initial exploration - if self.global_step < self.num_init_steps: - actions = th.stack([th.as_tensor(self.action_space.sample()) for _ in range(self.num_envs)]) - else: - actions = self.policy(obs, training=True) - - # update the learning rate - if anneal_lr: - for key in self.policy.optimizers.keys(): - utils.linear_lr_scheduler(self.policy.optimizers[key], self.global_step, num_train_steps, self.lr) - - # update agent - if self.global_step >= self.num_init_steps: - self.update() - # try to update storage - self.storage.update(self.metrics) - - # observe reward and next obs - next_obs, rews, terms, truncs, infos = self.env.step(actions) - - # pre-training mode - if self.pretraining: - rews = th.zeros_like(rews, device=self.device) - - # TODO: get real next observations - # As the vector environments autoreset for a terminating and truncating sub-environments, - # the returned observation and info is not the final step's observation or info which - # is instead stored in info as `final_observation` and `final_info`. So we need to get - # the real next observations from the infos and not to reset the environments. - real_next_obs = deepcopy(next_obs) - for idx, (term, trunc) in enumerate(zip(terms, truncs)): - if term.item() or trunc.item(): - # TODO: deal with dict observations - real_next_obs[idx] = th.as_tensor(infos["final_observation"][idx], device=self.device) # type: ignore[index] - - # add new transitions - self.storage.add(obs, actions, rews, terms, truncs, infos, real_next_obs) - self.global_step += self.num_envs - - # deal with the intrinsic reward module - # for modules like RE3, this will calculate the random embeddings - # and insert them into the storage. for modules like ICM, this - # will update the dynamic models. 
- if self.irs is not None: - self.irs.add(samples={"obs": obs, "actions": actions, "next_obs": real_next_obs}) # type: ignore - - # get episode information - eps_r, eps_l = utils.get_episode_statistics(infos) - episode_rewards.extend(eps_r) - episode_steps.extend(eps_l) - self.global_episode += len(eps_r) - - # log training information - if len(episode_rewards) >= 1 and (self.global_step % log_interval) == 0: - total_time = self.timer.total_time() - - # log to console - train_metrics = { - "step": self.global_step, - "episode": self.global_episode, - "episode_length": np.mean(list(episode_steps)), - "episode_reward": np.mean(list(episode_rewards)), - "fps": self.global_step / total_time, - "total_time": total_time, - } - self.logger.train(msg=train_metrics) - - # set the current observation - obs = next_obs - - # save model - if self.global_step % save_interval == 0: - self.save() - - # final save - self.save() - self.logger.info("Training Accomplished!") - self.logger.info(f"Model saved at: {self.work_dir / 'model'}") - - # close env - self.env.close() - if self.eval_env is not None: - self.eval_env.close() - - def eval(self, num_eval_episodes: int) -> Dict[str, Any]: - """Evaluation function. - - Args: - num_eval_episodes (int): The number of evaluation episodes. - - Returns: - The evaluation results. - """ - assert self.eval_env is not None, "No evaluation environment is provided!" - # reset the env - obs, infos = self.eval_env.reset(seed=self.seed) - episode_rewards: List[float] = [] - episode_steps: List[int] = [] - - # evaluation loop - while len(episode_rewards) < num_eval_episodes: - # sample actions - with th.no_grad(), utils.eval_mode(self): - actions = self.policy(obs, training=False) - - # observe reward and next obs - next_obs, rews, terms, truncs, infos = self.eval_env.step(actions) - - # get episode information - if "episode" in infos: - eps_r, eps_l = utils.get_episode_statistics(infos) - episode_rewards.extend(eps_r) - episode_steps.extend(eps_l) - - # set the current observation - obs = next_obs - - return { - "step": self.global_step, - "episode": self.global_episode, - "episode_length": np.mean(episode_steps), - "episode_reward": np.mean(episode_rewards), - "total_time": self.timer.total_time(), - } diff --git a/rllte/common/prototype/on_policy_agent.py b/rllte/common/prototype/on_policy_agent.py index 2a77f740..b98f6a97 100644 --- a/rllte/common/prototype/on_policy_agent.py +++ b/rllte/common/prototype/on_policy_agent.py @@ -59,9 +59,11 @@ def __init__( device: str = "cpu", pretraining: bool = False, num_steps: int = 128, + use_lstm: bool = False, ) -> None: super().__init__(env=env, eval_env=eval_env, tag=tag, seed=seed, device=device, pretraining=pretraining) self.num_steps = num_steps + self.use_lstm = use_lstm # attr annotations self.policy: OnPolicyType self.storage: RolloutStorageType @@ -106,7 +108,19 @@ def train( # get number of updates num_updates = int(num_train_steps // self.num_envs // self.num_steps) + # only if using lstm, initialize lstm state + if self.use_lstm: + lstm_state = ( + th.zeros(self.policy.lstm.num_layers, self.num_envs, self.policy.lstm.hidden_size).to(self.device), + th.zeros(self.policy.lstm.num_layers, self.num_envs, self.policy.lstm.hidden_size).to(self.device), + ) + done = th.zeros(self.num_envs, dtype=th.bool, device=self.device) + for update in range(num_updates): + # important for updating the policy lstm later + if self.use_lstm: + self.initial_lstm_state = (lstm_state[0].clone(), lstm_state[1].clone()) + # try to eval if (update 
% eval_interval) == 0 and (self.eval_env is not None): eval_metrics = self.eval(num_eval_episodes) @@ -121,9 +135,18 @@ def train( for _ in range(self.num_steps): # sample actions with th.no_grad(), utils.eval_mode(self): - actions, extra_policy_outputs = self.policy(obs, training=True) + if self.use_lstm: + actions, extra_policy_outputs = self.policy(obs, lstm_state, done, training=True) + lstm_state = extra_policy_outputs["lstm_state"] + del extra_policy_outputs["lstm_state"] + else: + actions, extra_policy_outputs = self.policy(obs, training=True) + # observe rewards and next obs next_obs, rews, terms, truncs, infos = self.env.step(actions) + + if self.use_lstm: + done = th.logical_or(terms, truncs) # pre-training mode if self.pretraining: @@ -142,7 +165,10 @@ def train( # get the value estimation of the last step with th.no_grad(): - last_values = self.policy.get_value(next_obs).detach() + if self.use_lstm: + last_values = self.policy.get_value(next_obs, lstm_state, done).detach() + else: + last_values = self.policy.get_value(next_obs).detach() # perform return and advantage estimation self.storage.compute_returns_and_advantages(last_values) @@ -225,18 +251,41 @@ def eval(self, num_eval_episodes: int) -> Dict[str, Any]: episode_rewards: List[float] = [] episode_steps: List[int] = [] + if self.use_lstm: + lstm_state = ( + th.zeros(self.policy.lstm.num_layers, self.num_envs, self.policy.lstm.hidden_size).to(self.device), + th.zeros(self.policy.lstm.num_layers, self.num_envs, self.policy.lstm.hidden_size).to(self.device), + ) + done = th.zeros(self.num_envs, dtype=th.bool, device=self.device) + # evaluation loop while len(episode_rewards) < num_eval_episodes: with th.no_grad(), utils.eval_mode(self): - actions, _ = self.policy(obs, training=False) + if self.use_lstm: + actions, extra_policy_outputs = self.policy(obs, lstm_state, done, training=False) + lstm_state = extra_policy_outputs["lstm_state"] + del extra_policy_outputs["lstm_state"] + else: + actions, _ = self.policy(obs, training=False) + next_obs, rews, terms, truncs, infos = self.eval_env.step(actions) + if self.use_lstm: + done = th.logical_or(terms, truncs) + # get episode information if "episode" in infos: eps_r, eps_l = utils.get_episode_statistics(infos) episode_rewards.extend(eps_r) episode_steps.extend(eps_l) - + + if self.use_lstm: + lstm_state = ( + th.zeros(self.policy.lstm.num_layers, self.num_envs, self.policy.lstm.hidden_size).to(self.device), + th.zeros(self.policy.lstm.num_layers, self.num_envs, self.policy.lstm.hidden_size).to(self.device), + ) + done = th.zeros(self.num_envs, dtype=th.bool, device=self.device) + # set the current observation obs = next_obs diff --git a/rllte/common/type_alias.py b/rllte/common/type_alias.py index c0ad24f9..c6bd3932 100644 --- a/rllte/common/type_alias.py +++ b/rllte/common/type_alias.py @@ -137,6 +137,17 @@ class VanillaRolloutBatch(NamedTuple): truncateds: th.Tensor old_log_probs: th.Tensor adv_targ: th.Tensor + +class EpisodicRolloutBatch(NamedTuple): + observations: th.Tensor + actions: th.Tensor + values: th.Tensor + returns: th.Tensor + terminateds: th.Tensor + truncateds: th.Tensor + old_log_probs: th.Tensor + adv_targ: th.Tensor + env_inds: th.Tensor class DictRolloutBatch(NamedTuple): diff --git a/rllte/env/__init__.py b/rllte/env/__init__.py deleted file mode 100644 index 1cf0e2e8..00000000 --- a/rllte/env/__init__.py +++ /dev/null @@ -1,59 +0,0 @@ -# ============================================================================= -# MIT License - -# Copyright (c) 2023 
Reinforcement Learning Evolution Foundation - -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: - -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. - -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. -# ============================================================================= - - -from .testing import make_bitflipping_env as make_bitflipping_env -from .testing import make_multibinary_env as make_multibinary_env -from .testing import make_multidiscrete_env as make_multidiscrete_env -from .testing import make_box_env as make_box_env -from .testing import make_discrete_env as make_discrete_env - -from .utils import make_rllte_env as make_rllte_env - -try: - from .atari import make_atari_env as make_atari_env - from .atari import make_envpool_atari_env as make_envpool_atari_env -except Exception: - pass - -try: - from .bullet import make_bullet_env as make_bullet_env -except Exception: - pass - -try: - from .dmc import make_dmc_env as make_dmc_env -except Exception: - pass - -try: - from .minigrid import make_minigrid_env as make_minigrid_env -except Exception: - pass - -try: - from .procgen import make_envpool_procgen_env as make_envpool_procgen_env - from .procgen import make_procgen_env as make_procgen_env -except Exception: - pass diff --git a/rllte/xploit/encoder/mnih_cnn_encoder.py b/rllte/xploit/encoder/mnih_cnn_encoder.py index aa0deed7..b0b76446 100644 --- a/rllte/xploit/encoder/mnih_cnn_encoder.py +++ b/rllte/xploit/encoder/mnih_cnn_encoder.py @@ -56,7 +56,7 @@ def __init__(self, observation_space: gym.Space, feature_dim: int = 0) -> None: nn.ReLU(), nn.Conv2d(32, 64, 4, stride=2), nn.ReLU(), - nn.Conv2d(64, 32, 3, stride=1), + nn.Conv2d(64, 64, 3, stride=1), nn.ReLU(), nn.Flatten(), ) diff --git a/rllte/xploit/policy/__init__.py b/rllte/xploit/policy/__init__.py index 4c36b7be..d8939d92 100644 --- a/rllte/xploit/policy/__init__.py +++ b/rllte/xploit/policy/__init__.py @@ -30,3 +30,4 @@ from .off_policy_stoch_actor_double_critic import OffPolicyStochActorDoubleCritic as OffPolicyStochActorDoubleCritic from .on_policy_decoupled_actor_critic import OnPolicyDecoupledActorCritic as OnPolicyDecoupledActorCritic from .on_policy_shared_actor_critic import OnPolicySharedActorCritic as OnPolicySharedActorCritic +from .on_policy_shared_actor_critic_lstm import OnPolicySharedActorCriticLSTM as OnPolicySharedActorCriticLSTM \ No newline at end of file diff --git a/rllte/xploit/policy/on_policy_shared_actor_critic_lstm.py b/rllte/xploit/policy/on_policy_shared_actor_critic_lstm.py new file mode 100644 index 00000000..0a8c8c91 --- /dev/null +++ 
b/rllte/xploit/policy/on_policy_shared_actor_critic_lstm.py @@ -0,0 +1,270 @@ +# ============================================================================= +# MIT License + +# Copyright (c) 2023 Reinforcement Learning Evolution Foundation + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# ============================================================================= +from copy import deepcopy +from pathlib import Path +from typing import Any, Dict, Optional, Tuple, Type + +import gymnasium as gym +import torch as th +from torch import nn + +from rllte.common.prototype import BaseDistribution as Distribution +from rllte.common.prototype import BasePolicy +from rllte.common.utils import ExportModel +from .utils import OnPolicyCritic, get_on_policy_actor + +# from torch.distributions import Distribution + + +class OnPolicySharedActorCriticLSTM(BasePolicy): + """Actor-Critic network for on-policy algorithms like `PPO` and `A2C`. Contains LSTM. + + Args: + observation_space (gym.Space): Observation space. + action_space (gym.Space): Action space. + feature_dim (int): Number of features accepted. + hidden_dim (int): Number of units per hidden layer. + opt_class (Type[th.optim.Optimizer]): Optimizer class. + opt_kwargs (Dict[str, Any]): Optimizer keyword arguments. + aux_critic (bool): Use auxiliary critic or not, for `PPG` agent. + init_fn (str): Parameters initialization method. + + Returns: + Actor-Critic-LSTM network instance. + """ + + def __init__( + self, + observation_space: gym.Space, + action_space: gym.Space, + feature_dim: int, + hidden_dim: int = 512, + opt_class: Type[th.optim.Optimizer] = th.optim.Adam, + opt_kwargs: Optional[Dict[str, Any]] = None, + aux_critic: bool = False, + init_fn: str = "orthogonal", + ) -> None: + if opt_kwargs is None: + opt_kwargs = {} + super().__init__( + observation_space=observation_space, + action_space=action_space, + feature_dim=feature_dim, + hidden_dim=hidden_dim, + opt_class=opt_class, + opt_kwargs=opt_kwargs, + init_fn=init_fn, + ) + + assert self.action_type in [ + "Discrete", + "Box", + "MultiBinary", + "MultiDiscrete", + ], f"Unsupported action type {self.action_type}!" 
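+
+        # NOTE: the encoder output (feature_dim) is compressed by the LSTM below to
+        # feature_dim // 4 features, which both the actor and the critic heads consume;
+        # the recurrent state is carried explicitly by the caller and reset wherever
+        # `done` is true (see `get_states`).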
+
+        # build lstm
+        self.lstm = nn.LSTM(feature_dim, feature_dim // 4)
+        for name, param in self.lstm.named_parameters():
+            if "bias" in name:
+                nn.init.constant_(param, 0)
+            elif "weight" in name:
+                nn.init.orthogonal_(param, 1.0)
+
+        # build actor and critic
+        actor_kwargs = dict(
+            obs_shape=self.obs_shape,
+            action_dim=self.policy_action_dim,
+            feature_dim=feature_dim // 4,
+            hidden_dim=self.hidden_dim,
+        )
+        if self.nvec is not None:
+            actor_kwargs["nvec"] = self.nvec
+
+        self.actor = get_on_policy_actor(action_type=self.action_type, actor_kwargs=actor_kwargs)
+
+        self.critic = OnPolicyCritic(
+            obs_shape=self.obs_shape,
+            action_dim=self.policy_action_dim,
+            feature_dim=feature_dim // 4,
+            hidden_dim=self.hidden_dim,
+        )
+        if aux_critic:
+            self.aux_critic = deepcopy(self.critic)
+
+    @staticmethod
+    def describe() -> None:
+        """Describe the policy."""
+        print("\n")
+        print("=" * 80)
+        print(f"{'Name'.ljust(10)} : OnPolicySharedActorCriticLSTM")
+        print(f"{'Structure'.ljust(10)} : self.encoder (shared by actor and critic), self.lstm, self.actor, self.critic")
+        print(f"{''.ljust(10)} : self.aux_critic (optional, for PPG)")
+        print(f"{'Forward'.ljust(10)} : obs -> self.encoder -> self.lstm -> self.actor -> actions")
+        print(f"{''.ljust(10)} : obs -> self.encoder -> self.lstm -> self.critic -> values")
+        print(f"{''.ljust(10)} : actions -> log_probs")
+        print(f"{'Optimizers'.ljust(10)} : self.optimizers['opt'] -> (self.encoder, self.lstm, self.actor, self.critic)")
+        print(f"{''.ljust(10)} : self.optimizers['opt'] -> self.aux_critic (optional, for PPG)")
+        print("=" * 80)
+        print("\n")
+
+    def freeze(self, encoder: nn.Module, dist: Distribution) -> None:
+        """Freeze all the elements like `encoder` and `dist`.
+
+        Args:
+            encoder (nn.Module): Encoder network.
+            dist (Distribution): Distribution class.
+
+        Returns:
+            None.
+        """
+        # set encoder
+        assert encoder is not None, "Encoder should not be None!"
+        self.encoder = encoder
+        # set distribution
+        assert dist is not None, "Distribution should not be None!"
+        self.dist = dist
+        # initialize parameters
+        self.apply(self.init_fn)
+        # build optimizers
+        self._optimizers["opt"] = self.opt_class(self.parameters(), **self.opt_kwargs)
+
+    def get_states(
+        self, obs: th.Tensor, lstm_state: Tuple[th.Tensor, th.Tensor], done: th.Tensor
+    ) -> Tuple[th.Tensor, Tuple[th.Tensor, th.Tensor]]:
+        """Encode observations and roll them through the LSTM step by step,
+        resetting the recurrent state wherever `done` is true."""
+        hidden = self.encoder(obs)
+
+        # LSTM logic
+        batch_size = lstm_state[0].shape[1]
+        hidden = hidden.reshape((-1, batch_size, self.lstm.input_size))
+        done = done.reshape((-1, batch_size)).long()
+        new_hidden = []
+        for h, d in zip(hidden, done):
+            h, lstm_state = self.lstm(
+                h.unsqueeze(0),
+                (
+                    (1.0 - d).view(1, -1, 1) * lstm_state[0],
+                    (1.0 - d).view(1, -1, 1) * lstm_state[1],
+                ),
+            )
+            new_hidden += [h]
+        new_hidden = th.flatten(th.cat(new_hidden), 0, 1)
+        return new_hidden, lstm_state
+
+    def forward(
+        self, obs: th.Tensor, lstm_state: Tuple[th.Tensor, th.Tensor], done: th.Tensor, training: bool = True
+    ) -> Tuple[th.Tensor, Dict[str, th.Tensor]]:
+        """Get actions and estimated values for observations.
+
+        Args:
+            obs (th.Tensor): Observations.
+            lstm_state (Tuple[th.Tensor, th.Tensor]): Hidden and cell states of the LSTM.
+            done (th.Tensor): Episode-termination flags used to reset the recurrent state.
+            training (bool): training mode, `True` or `False`.
+
+        Returns:
+            Sampled actions, estimated values, and log of probabilities for observations when `training` is `True`,
+            else only deterministic actions.
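+
+            The new LSTM state returned under the `"lstm_state"` key must be fed back
+            on the next call; `OnPolicyAgent.train` does this and uses `done` to reset
+            the state at episode boundaries.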
+        """
+        h, lstm_state = self.get_states(obs, lstm_state, done)
+
+        policy_outputs = self.actor.get_policy_outputs(h)
+        dist = self.dist(*policy_outputs)
+
+        if training:
+            actions = dist.sample()
+            log_probs = dist.log_prob(actions)
+            return actions, {"values": self.critic(h), "log_probs": log_probs, "lstm_state": lstm_state}
+        else:
+            actions = dist.mean
+            return actions, {"lstm_state": lstm_state}
+
+    def get_value(self, obs: th.Tensor, lstm_state: Tuple[th.Tensor, th.Tensor], done: th.Tensor) -> th.Tensor:
+        """Get estimated values for observations.
+
+        Args:
+            obs (th.Tensor): Observations.
+            lstm_state (Tuple[th.Tensor, th.Tensor]): Hidden and cell states of the LSTM.
+            done (th.Tensor): Episode-termination flags used to reset the recurrent state.
+
+        Returns:
+            Estimated values.
+        """
+        return self.critic(self.get_states(obs, lstm_state, done)[0])
+
+    def evaluate_actions(
+        self, obs: th.Tensor, actions: th.Tensor, lstm_state: Tuple[th.Tensor, th.Tensor], done: th.Tensor
+    ) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
+        """Evaluate actions according to the current policy given the observations.
+
+        Args:
+            obs (th.Tensor): Sampled observations.
+            actions (th.Tensor): Sampled actions.
+            lstm_state (Tuple[th.Tensor, th.Tensor]): Hidden and cell states of the LSTM.
+            done (th.Tensor): Episode-termination flags used to reset the recurrent state.
+
+        Returns:
+            Estimated values, log of the probability evaluated at `actions`, entropy of distribution.
+        """
+        h, _ = self.get_states(obs, lstm_state, done)
+        policy_outputs = self.actor.get_policy_outputs(h)
+        dist = self.dist(*policy_outputs)
+
+        log_probs = dist.log_prob(actions)
+        entropy = dist.entropy().mean()
+
+        return self.critic(h), log_probs, entropy
+
+    def get_policy_outputs(self, obs: th.Tensor, lstm_state: Tuple[th.Tensor, th.Tensor], done: th.Tensor) -> th.Tensor:
+        """Get policy outputs for training.
+
+        Args:
+            obs (th.Tensor): Observations.
+            lstm_state (Tuple[th.Tensor, th.Tensor]): Hidden and cell states of the LSTM.
+            done (th.Tensor): Episode-termination flags used to reset the recurrent state.
+
+        Returns:
+            Policy outputs like unnormalized probabilities for `Discrete` tasks.
+        """
+        h, _ = self.get_states(obs, lstm_state, done)
+        policy_outputs = self.actor.get_policy_outputs(h)
+        return th.cat(policy_outputs, dim=1)
+
+    def get_dist_and_aux_value(
+        self, obs: th.Tensor, lstm_state: Tuple[th.Tensor, th.Tensor], done: th.Tensor
+    ) -> Tuple[Distribution, th.Tensor, th.Tensor]:
+        """Get probs and auxiliary estimated values for auxiliary phase update.
+
+        Args:
+            obs (th.Tensor): Sampled observations.
+            lstm_state (Tuple[th.Tensor, th.Tensor]): Hidden and cell states of the LSTM.
+            done (th.Tensor): Episode-termination flags used to reset the recurrent state.
+
+        Returns:
+            Sample distribution, estimated values, auxiliary estimated values.
+        """
+        h, _ = self.get_states(obs, lstm_state, done)
+        policy_outputs = self.actor.get_policy_outputs(h)
+        dist = self.dist(*policy_outputs)
+
+        return dist, self.critic(h.detach()), self.aux_critic(h)
+
+    def save(self, path: Path, pretraining: bool, global_step: int) -> None:
+        """Save models.
+
+        Args:
+            path (Path): Save path.
+            pretraining (bool): Pre-training mode.
+            global_step (int): Global training step.
+
+        Returns:
+            None.
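+
+        Note:
+            In pre-training mode the full `state_dict` is checkpointed as
+            `pretrained_{global_step}.pth`; otherwise only the encoder and actor are
+            exported via `ExportModel` as `agent_{global_step}.pth` for inference.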
+        """
+        if pretraining:  # pretraining
+            th.save(self.state_dict(), path / f"pretrained_{global_step}.pth")
+        else:
+            export_model = ExportModel(encoder=self.encoder, actor=self.actor)
+            th.save(export_model, path / f"agent_{global_step}.pth")
diff --git a/rllte/xploit/storage/__init__.py b/rllte/xploit/storage/__init__.py
index fbc3683d..c0ffb58c 100644
--- a/rllte/xploit/storage/__init__.py
+++ b/rllte/xploit/storage/__init__.py
@@ -30,3 +30,4 @@
 from .vanilla_distributed_storage import VanillaDistributedStorage as VanillaDistributedStorage
 from .vanilla_replay_storage import VanillaReplayStorage as VanillaReplayStorage
 from .vanilla_rollout_storage import VanillaRolloutStorage as VanillaRolloutStorage
+from .episodic_rollout_storage import EpisodicRolloutStorage as EpisodicRolloutStorage
\ No newline at end of file
diff --git a/rllte/xploit/storage/episodic_rollout_storage.py b/rllte/xploit/storage/episodic_rollout_storage.py
new file mode 100644
index 00000000..53b1e989
--- /dev/null
+++ b/rllte/xploit/storage/episodic_rollout_storage.py
@@ -0,0 +1,196 @@
+# =============================================================================
+# MIT License
+
+# Copyright (c) 2023 Reinforcement Learning Evolution Foundation
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# =============================================================================
+
+
+from typing import Dict, Generator
+
+import gymnasium as gym
+import torch as th
+import numpy as np
+from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler
+
+from rllte.common.prototype import BaseStorage
+from rllte.common.type_alias import EpisodicRolloutBatch
+
+
+class EpisodicRolloutStorage(BaseStorage):
+    """Episodic rollout storage for on-policy algorithms that use an LSTM.
+    It is the same as VanillaRolloutStorage but samples entire trajectories instead of batches of individual steps.
+
+    Args:
+        observation_space (gym.Space): The observation space of environment.
+        action_space (gym.Space): The action space of environment.
+        device (str): Device to convert the data.
+        storage_size (int): The capacity of the storage. Here it refers to the length of each rollout.
+        num_envs (int): The number of parallel environments.
+        discount (float): The discount factor.
+        gae_lambda (float): Weighting coefficient for generalized advantage estimation (GAE).
+        num_batches (int): Number of mini-batches; each one holds the full rollouts of
+            `num_envs // num_batches` environments.
+
+    Returns:
+        Episodic rollout storage instance.
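+
+    Note:
+        The per-batch size is derived internally as `(num_envs * storage_size) // num_batches`;
+        `sample()` draws environment indices and yields their whole trajectories, so the
+        LSTM can be unrolled over intact sequences during updates.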
+ """ + + def __init__( + self, + observation_space: gym.Space, + action_space: gym.Space, + device: str = "cpu", + storage_size: int = 256, + num_envs: int = 8, + discount: float = 0.999, + gae_lambda: float = 0.95, + num_batches: int = 4, + ) -> None: + batch_size = (num_envs * storage_size) // num_batches + super().__init__(observation_space, action_space, device, storage_size, batch_size, num_envs) + self.discount = discount + self.gae_lambda = gae_lambda + self.num_batches = num_batches + self.reset() + + def reset(self) -> None: + """Reset the storage.""" + # data containers + self.observations = th.empty( + size=(self.storage_size + 1, self.num_envs, *self.obs_shape), dtype=th.float32, device=self.device + ) + self.actions = th.empty(size=(self.storage_size, self.num_envs, self.action_dim), dtype=th.float32, device=self.device) + self.rewards = th.empty(size=(self.storage_size, self.num_envs), dtype=th.float32, device=self.device) + self.terminateds = th.empty(size=(self.storage_size + 1, self.num_envs), dtype=th.float32, device=self.device) + self.truncateds = th.empty(size=(self.storage_size + 1, self.num_envs), dtype=th.float32, device=self.device) + # first next_terminated + self.terminateds[0].fill_(0.0) + self.truncateds[0].fill_(0.0) + # extra part + self.log_probs = th.empty(size=(self.storage_size, self.num_envs), dtype=th.float32, device=self.device) + self.values = th.empty(size=(self.storage_size, self.num_envs), dtype=th.float32, device=self.device) + self.returns = th.empty(size=(self.storage_size, self.num_envs), dtype=th.float32, device=self.device) + self.advantages = th.empty(size=(self.storage_size, self.num_envs), dtype=th.float32, device=self.device) + super().reset() + + def add( + self, + observations: th.Tensor, + actions: th.Tensor, + rewards: th.Tensor, + terminateds: th.Tensor, + truncateds: th.Tensor, + infos: Dict, + next_observations: th.Tensor, + log_probs: th.Tensor, + values: th.Tensor, + ) -> None: + """Add sampled transitions into storage. + + Args: + observations (th.Tensor): Observations. + actions (th.Tensor): Actions. + rewards (th.Tensor): Rewards. + terminateds (th.Tensor): Termination signals. + truncateds (th.Tensor): Truncation signals. + infos (Dict): Extra information. + next_observations (th.Tensor): Next observations. + log_probs (th.Tensor): Log of the probability evaluated at `actions`. + values (th.Tensor): Estimated values. + + Returns: + None. + """ + self.observations[self.step].copy_(observations) + self.actions[self.step].copy_(actions.view(self.num_envs, self.action_dim)) + self.rewards[self.step].copy_(rewards) + self.terminateds[self.step + 1].copy_(terminateds) + self.truncateds[self.step + 1].copy_(truncateds) + self.observations[self.step + 1].copy_(next_observations) + self.log_probs[self.step].copy_(log_probs) + self.values[self.step].copy_(values.flatten()) + + self.full = True if self.step == self.storage_size - 1 else False + self.step = (self.step + 1) % self.storage_size + + def update(self) -> None: + """Update the terminal state of each env.""" + self.terminateds[0].copy_(self.terminateds[-1]) + self.truncateds[0].copy_(self.truncateds[-1]) + + def compute_returns_and_advantages(self, last_values: th.Tensor) -> None: + """Perform generalized advantage estimation (GAE). + + Args: + last_values (th.Tensor): Estimated values of the last step. + + Returns: + None. 
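+
+        Note:
+            A sketch of the standard GAE recursion implemented below:
+
+                delta_t = r_t + discount * V_{t+1} * (1 - terminated_{t+1}) - V_t
+                A_t = delta_t + discount * gae_lambda * (1 - terminated_{t+1}) * A_{t+1}
+
+            with the accumulator additionally zeroed at time limits, via
+            `gae = gae * (1.0 - truncated_{t+1})`.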
+ """ + gae = 0 + for step in reversed(range(self.storage_size)): + if step == self.storage_size - 1: + next_values = last_values[:, 0] + else: + next_values = self.values[step + 1] + next_non_terminal = 1.0 - self.terminateds[step + 1] + delta = self.rewards[step] + self.discount * next_values * next_non_terminal - self.values[step] + gae = delta + self.discount * self.gae_lambda * next_non_terminal * gae + # time limit + gae = gae * (1.0 - self.truncateds[step + 1]) + self.advantages[step] = gae + + self.returns = self.advantages + self.values + self.advantages = (self.advantages - self.advantages.mean()) / (self.advantages.std() + 1e-5) + + def sample(self) -> Generator: + """ + Choose a minibatch of environment indices and sample the entire rollout for those minibatches. + By not sampling uniform transitions, we can now train an LSTM model on entire trajectories + """ + assert self.full, "Cannot sample when the storage is not full!" + _batch_size = self.num_envs // self.num_batches + sampler = BatchSampler(SubsetRandomSampler(range(self.num_envs)), _batch_size, drop_last=True) + + b_obs = self.observations[:-1].reshape(-1, *self.obs_shape) + b_act = self.actions.reshape(-1, *self.action_shape) + b_val = self.values.reshape(-1) + b_ret = self.returns.reshape(-1) + b_ter = self.terminateds[:-1].reshape(-1) + b_tru = self.truncateds[:-1].reshape(-1) + b_log = self.log_probs.reshape(-1) + b_adv = self.advantages.reshape(-1) + + flat_idcs = np.arange(self.num_envs * self.storage_size).reshape(self.storage_size, self.num_envs) + + for indices in sampler: + mb_inds = flat_idcs[:, indices].ravel() + + yield EpisodicRolloutBatch( + observations=b_obs[mb_inds], + actions=b_act[mb_inds], + values=b_val[mb_inds], + returns=b_ret[mb_inds], + terminateds=b_ter[mb_inds], + truncateds=b_tru[mb_inds], + old_log_probs=b_log[mb_inds], + adv_targ=b_adv[mb_inds], + env_inds=indices, + ) \ No newline at end of file From ccc89fed76f0d6c86458792ffbc332098a423710 Mon Sep 17 00:00:00 2001 From: roger-creus Date: Thu, 23 Nov 2023 10:23:28 -0500 Subject: [PATCH 2/5] added training example --- examples/train_ppoLstm_atari.ipynb | 74 ++++++++++++++++++++++++++++++ rllte/agent/ppo_lstm.py | 2 +- 2 files changed, 75 insertions(+), 1 deletion(-) create mode 100644 examples/train_ppoLstm_atari.ipynb diff --git a/examples/train_ppoLstm_atari.ipynb b/examples/train_ppoLstm_atari.ipynb new file mode 100644 index 00000000..57744373 --- /dev/null +++ b/examples/train_ppoLstm_atari.ipynb @@ -0,0 +1,74 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import warnings\n", + "warnings.filterwarnings('ignore')\n", + "\n", + "from rllte.agent import PPO_LSTM\n", + "from rllte.env import make_envpool_atari_env" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "device = \"cuda:0\"\n", + "num_envs = 16\n", + "\n", + "env = make_envpool_atari_env(\n", + " env_id=\"SpaceInvaders-v5\",\n", + " device=device,\n", + " num_envs=num_envs,\n", + " asynchronous=False\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "agent = PPO_LSTM(\n", + " env=env, \n", + " device=device,\n", + " tag=\"ppo_lstm_atari\",\n", + ")\n", + "\n", + "print(\"===== AGENT =====\")\n", + "print(agent.encoder)\n", + "print(agent.policy)\n", + "\n", + "agent.train(num_train_steps=10_000_000)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": 
"rllte", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.18" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/rllte/agent/ppo_lstm.py b/rllte/agent/ppo_lstm.py index 074c154f..6f525cfb 100644 --- a/rllte/agent/ppo_lstm.py +++ b/rllte/agent/ppo_lstm.py @@ -87,7 +87,7 @@ def __init__( vf_coef: float = 0.5, ent_coef: float = 0.01, max_grad_norm: float = 0.5, - discount: float = 0.999, + discount: float = 0.99, init_fn: str = "orthogonal", num_batches: int = 4, ) -> None: From c99234b2944cebf531fcb7a66e6e5d15c3760228 Mon Sep 17 00:00:00 2001 From: roger-creus Date: Thu, 23 Nov 2023 10:32:18 -0500 Subject: [PATCH 3/5] Fixed missing scripts from main --- rllte/env/minigrid/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rllte/env/minigrid/__init__.py b/rllte/env/minigrid/__init__.py index bb14c2a1..d1dec9a6 100644 --- a/rllte/env/minigrid/__init__.py +++ b/rllte/env/minigrid/__init__.py @@ -145,4 +145,4 @@ def _thunk(): envs = SyncVectorEnv(envs) envs = RecordEpisodeStatistics(envs) - return Gymnasium2Torch(envs, device=device) + return Gymnasium2Torch(envs, device=device) \ No newline at end of file From 480d41f181c3cfb948a65f08d3da6a748c201637 Mon Sep 17 00:00:00 2001 From: roger-creus Date: Thu, 23 Nov 2023 10:37:40 -0500 Subject: [PATCH 4/5] run make format --- rllte/agent/ppo_lstm.py | 2 +- rllte/common/prototype/base_agent.py | 2 +- rllte/common/prototype/base_distribution.py | 1 + rllte/common/prototype/off_policy_agent.py | 247 ++++++++++++++++++ rllte/env/__init__.py | 58 ++++ rllte/env/atari/__init__.py | 20 +- rllte/env/atari/wrappers.py | 2 +- rllte/env/testing/__init__.py | 2 +- rllte/hub/atari.py | 12 +- rllte/hub/bucket.py | 4 +- rllte/hub/dmc.py | 9 +- rllte/hub/minigrid.py | 9 +- rllte/hub/procgen.py | 11 +- .../on_policy_shared_actor_critic_lstm.py | 1 + rllte/xploit/policy/utils.py | 2 +- rllte/xploit/storage/__init__.py | 2 +- rllte/xploit/storage/dict_replay_storage.py | 3 +- .../storage/episodic_rollout_storage.py | 2 +- tests/test_reward.py | 6 +- tests/test_storage.py | 2 +- 20 files changed, 352 insertions(+), 45 deletions(-) create mode 100644 rllte/common/prototype/off_policy_agent.py create mode 100644 rllte/env/__init__.py diff --git a/rllte/agent/ppo_lstm.py b/rllte/agent/ppo_lstm.py index 6f525cfb..59a96338 100644 --- a/rllte/agent/ppo_lstm.py +++ b/rllte/agent/ppo_lstm.py @@ -31,7 +31,7 @@ from rllte.common.prototype import OnPolicyAgent from rllte.common.type_alias import VecEnv -from rllte.xploit.encoder import IdentityEncoder, MnihCnnEncoder, EspeholtResidualEncoder, PathakCnnEncoder +from rllte.xploit.encoder import EspeholtResidualEncoder, IdentityEncoder, MnihCnnEncoder, PathakCnnEncoder from rllte.xploit.policy import OnPolicySharedActorCriticLSTM from rllte.xploit.storage import EpisodicRolloutStorage from rllte.xplore.distribution import Bernoulli, Categorical, DiagonalGaussian, MultiCategorical diff --git a/rllte/common/prototype/base_agent.py b/rllte/common/prototype/base_agent.py index 5aba434b..1c7b0c12 100644 --- a/rllte/common/prototype/base_agent.py +++ b/rllte/common/prototype/base_agent.py @@ -28,7 +28,7 @@ from abc import ABC, abstractmethod from datetime import datetime from pathlib import Path -from typing import Dict, Optional, Any +from typing import 
Any, Dict, Optional import numpy as np import pynvml diff --git a/rllte/common/prototype/base_distribution.py b/rllte/common/prototype/base_distribution.py index 9cae29b1..fc1e525e 100644 --- a/rllte/common/prototype/base_distribution.py +++ b/rllte/common/prototype/base_distribution.py @@ -24,6 +24,7 @@ from typing import Any + import torch as th from torch.distributions import Distribution diff --git a/rllte/common/prototype/off_policy_agent.py b/rllte/common/prototype/off_policy_agent.py new file mode 100644 index 00000000..27bd04f4 --- /dev/null +++ b/rllte/common/prototype/off_policy_agent.py @@ -0,0 +1,247 @@ +# ============================================================================= +# MIT License + +# Copyright (c) 2023 Reinforcement Learning Evolution Foundation + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# ============================================================================= + + +from collections import deque +from copy import deepcopy +from typing import Any, Deque, Dict, List, Optional + +import numpy as np +import torch as th + +from rllte.common import utils +from rllte.common.prototype.base_agent import BaseAgent +from rllte.common.type_alias import OffPolicyType, ReplayStorageType, VecEnv + + +class OffPolicyAgent(BaseAgent): + """Trainer for off-policy algorithms. + + Args: + env (VecEnv): Vectorized environments for training. + eval_env (Optional[VecEnv]): Vectorized environments for evaluation. + tag (str): An experiment tag. + seed (int): Random seed for reproduction. + device (str): Device (cpu, cuda, ...) on which the code should be run. + pretraining (bool): Turn on pre-training model or not. + num_init_steps (int): Number of initial exploration steps. + **kwargs: Arbitrary arguments such as `batch_size` and `hidden_dim`. + + Returns: + Off-policy agent instance. + """ + + def __init__( + self, + env: VecEnv, + eval_env: Optional[VecEnv] = None, + tag: str = "default", + seed: int = 1, + device: str = "cpu", + pretraining: bool = False, + num_init_steps: int = 2000, + **kwargs, + ) -> None: + super().__init__(env=env, eval_env=eval_env, tag=tag, seed=seed, device=device, pretraining=pretraining) + self.num_init_steps = num_init_steps + # attr annotations + self.policy: OffPolicyType + self.storage: ReplayStorageType + + def update(self) -> None: + """Update the agent. 
Implemented by individual algorithms.""" + raise NotImplementedError + + def train( # noqa: C901 + self, + num_train_steps: int, + init_model_path: Optional[str] = None, + log_interval: int = 1, + eval_interval: int = 5000, + save_interval: int = 5000, + num_eval_episodes: int = 10, + th_compile: bool = False, + anneal_lr: bool = False + ) -> None: + """Training function. + + Args: + num_train_steps (int): The number of training steps. + init_model_path (Optional[str]): The path of the initial model. + log_interval (int): The interval of logging. + eval_interval (int): The interval of evaluation. + save_interval (int): The interval of saving model. + num_eval_episodes (int): The number of evaluation episodes. + th_compile (bool): Whether to use `th.compile` or not. + anneal_lr (bool): Whether to anneal the learning rate or not. + + Returns: + None. + """ + # freeze the agent and get ready for training + self.freeze(init_model_path=init_model_path, th_compile=th_compile) + + # reset the env + episode_rewards: Deque = deque(maxlen=10) + episode_steps: Deque = deque(maxlen=10) + obs, infos = self.env.reset(seed=self.seed) + + # training loop + while self.global_step < num_train_steps: + # try to eval + if (self.global_step % eval_interval) == 0 and (self.eval_env is not None): + eval_metrics = self.eval(num_eval_episodes) + + # log to console + self.logger.eval(msg=eval_metrics) + + # sample actions + with th.no_grad(), utils.eval_mode(self): + # Initial exploration + if self.global_step < self.num_init_steps: + actions = th.stack([th.as_tensor(self.action_space.sample()) for _ in range(self.num_envs)]) + else: + actions = self.policy(obs, training=True) + + # update the learning rate + if anneal_lr: + for key in self.policy.optimizers.keys(): + utils.linear_lr_scheduler(self.policy.optimizers[key], self.global_step, num_train_steps, self.lr) + + # update agent + if self.global_step >= self.num_init_steps: + self.update() + # try to update storage + self.storage.update(self.metrics) + + # observe reward and next obs + next_obs, rews, terms, truncs, infos = self.env.step(actions) + + # pre-training mode + if self.pretraining: + rews = th.zeros_like(rews, device=self.device) + + # TODO: get real next observations + # As the vector environments autoreset for a terminating and truncating sub-environments, + # the returned observation and info is not the final step's observation or info which + # is instead stored in info as `final_observation` and `final_info`. So we need to get + # the real next observations from the infos and not to reset the environments. + real_next_obs = deepcopy(next_obs) + for idx, (term, trunc) in enumerate(zip(terms, truncs)): + if term.item() or trunc.item(): + # TODO: deal with dict observations + try: + real_next_obs[idx] = th.as_tensor(infos["final_observation"][idx], device=self.device) # type: ignore[index] + except: + pass + + # add new transitions + self.storage.add(obs, actions.unsqueeze(-1), rews, terms, truncs, infos, real_next_obs) + self.global_step += self.num_envs + + # deal with the intrinsic reward module + # for modules like RE3, this will calculate the random embeddings + # and insert them into the storage. for modules like ICM, this + # will update the dynamic models. 
+ if self.irs is not None: + self.irs.add(samples={"obs": obs, "actions": actions, "next_obs": real_next_obs}) # type: ignore + + # get episode information + eps_r, eps_l = utils.get_episode_statistics(infos) + episode_rewards.extend(eps_r) + episode_steps.extend(eps_l) + self.global_episode += len(eps_r) + + # log training information + if len(episode_rewards) >= 1 and (self.global_step % log_interval) == 0: + total_time = self.timer.total_time() + + # log to console + train_metrics = { + "step": self.global_step, + "episode": self.global_episode, + "episode_length": np.mean(list(episode_steps)), + "episode_reward": np.mean(list(episode_rewards)), + "fps": self.global_step / total_time, + "total_time": total_time, + } + self.logger.train(msg=train_metrics) + + # set the current observation + obs = next_obs + + # save model + if self.global_step % save_interval == 0: + self.save() + + # final save + self.save() + self.logger.info("Training Accomplished!") + self.logger.info(f"Model saved at: {self.work_dir / 'model'}") + + # close env + self.env.close() + if self.eval_env is not None: + self.eval_env.close() + + def eval(self, num_eval_episodes: int) -> Dict[str, Any]: + """Evaluation function. + + Args: + num_eval_episodes (int): The number of evaluation episodes. + + Returns: + The evaluation results. + """ + assert self.eval_env is not None, "No evaluation environment is provided!" + # reset the env + obs, infos = self.eval_env.reset(seed=self.seed) + episode_rewards: List[float] = [] + episode_steps: List[int] = [] + + # evaluation loop + while len(episode_rewards) < num_eval_episodes: + # sample actions + with th.no_grad(), utils.eval_mode(self): + actions = self.policy(obs, training=False) + + # observe reward and next obs + next_obs, rews, terms, truncs, infos = self.eval_env.step(actions) + + # get episode information + if "episode" in infos: + eps_r, eps_l = utils.get_episode_statistics(infos) + episode_rewards.extend(eps_r) + episode_steps.extend(eps_l) + + # set the current observation + obs = next_obs + + return { + "step": self.global_step, + "episode": self.global_episode, + "episode_length": np.mean(episode_steps), + "episode_reward": np.mean(episode_rewards), + "total_time": self.timer.total_time(), + } diff --git a/rllte/env/__init__.py b/rllte/env/__init__.py new file mode 100644 index 00000000..588ebee7 --- /dev/null +++ b/rllte/env/__init__.py @@ -0,0 +1,58 @@ +# ============================================================================= +# MIT License + +# Copyright (c) 2023 Reinforcement Learning Evolution Foundation + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# ============================================================================= + + +from .testing import make_bitflipping_env as make_bitflipping_env +from .testing import make_box_env as make_box_env +from .testing import make_discrete_env as make_discrete_env +from .testing import make_multibinary_env as make_multibinary_env +from .testing import make_multidiscrete_env as make_multidiscrete_env +from .utils import make_rllte_env as make_rllte_env + +try: + from .atari import make_atari_env as make_atari_env + from .atari import make_envpool_atari_env as make_envpool_atari_env +except Exception: + pass + +try: + from .bullet import make_bullet_env as make_bullet_env +except Exception: + pass + +try: + from .dmc import make_dmc_env as make_dmc_env +except Exception: + pass + +try: + from .minigrid import make_minigrid_env as make_minigrid_env +except Exception: + pass + +try: + from .procgen import make_envpool_procgen_env as make_envpool_procgen_env + from .procgen import make_procgen_env as make_procgen_env +except Exception: + pass \ No newline at end of file diff --git a/rllte/env/atari/__init__.py b/rllte/env/atari/__init__.py index 5f9b3717..65c7df81 100644 --- a/rllte/env/atari/__init__.py +++ b/rllte/env/atari/__init__.py @@ -28,17 +28,15 @@ import gymnasium as gym import numpy as np from gymnasium.vector import AsyncVectorEnv, SyncVectorEnv -from gymnasium.wrappers import (FrameStack, - GrayScaleObservation, - RecordEpisodeStatistics, - ResizeObservation, - TransformReward) - -from rllte.env.atari.wrappers import (EpisodicLifeEnv, - FireResetEnv, - MaxAndSkipEnv, - NoopResetEnv, - RecordEpisodeStatistics4EnvPool) +from gymnasium.wrappers import FrameStack, GrayScaleObservation, RecordEpisodeStatistics, ResizeObservation, TransformReward + +from rllte.env.atari.wrappers import ( + EpisodicLifeEnv, + FireResetEnv, + MaxAndSkipEnv, + NoopResetEnv, + RecordEpisodeStatistics4EnvPool, +) from rllte.env.utils import EnvPoolAsync2Gymnasium, EnvPoolSync2Gymnasium, Gymnasium2Torch diff --git a/rllte/env/atari/wrappers.py b/rllte/env/atari/wrappers.py index a3e85488..231c0080 100644 --- a/rllte/env/atari/wrappers.py +++ b/rllte/env/atari/wrappers.py @@ -23,7 +23,7 @@ # ============================================================================= -from typing import Any, Dict, Tuple, Optional +from typing import Any, Dict, Optional, Tuple import gymnasium as gym import numpy as np diff --git a/rllte/env/testing/__init__.py b/rllte/env/testing/__init__.py index 5b8dbc97..02095983 100644 --- a/rllte/env/testing/__init__.py +++ b/rllte/env/testing/__init__.py @@ -23,8 +23,8 @@ # ============================================================================= +from .bitflipping import make_bitflipping_env as make_bitflipping_env from .box import make_box_env as make_box_env from .discrete import make_discrete_env as make_discrete_env -from .bitflipping import make_bitflipping_env as make_bitflipping_env from .multibinary import make_multibinary_env as make_multibinary_env from .multidiscrete import make_multidiscrete_env as make_multidiscrete_env diff --git a/rllte/hub/atari.py b/rllte/hub/atari.py index c2484fcc..c51b7de4 100644 --- a/rllte/hub/atari.py +++ b/rllte/hub/atari.py @@ -23,17 +23,17 @@ # 
============================================================================= -from huggingface_hub import hf_hub_download -from typing import Dict, Callable -from torch import nn +from typing import Callable, Dict import numpy as np import torch as th +from huggingface_hub import hf_hub_download +from torch import nn -from rllte.hub.bucket import Bucket -from rllte.agent import A2C, PPO, IMPALA -from rllte.env import make_atari_env, make_envpool_atari_env +from rllte.agent import A2C, IMPALA, PPO from rllte.common.prototype import BaseAgent +from rllte.env import make_atari_env, make_envpool_atari_env +from rllte.hub.bucket import Bucket class Atari(Bucket): diff --git a/rllte/hub/bucket.py b/rllte/hub/bucket.py index faeb56da..e8b54570 100644 --- a/rllte/hub/bucket.py +++ b/rllte/hub/bucket.py @@ -25,11 +25,13 @@ from abc import ABC, abstractmethod from typing import Callable, Dict, List, Optional -from torch import nn + import numpy as np +from torch import nn from rllte.common.prototype import BaseAgent + class Bucket(ABC): """Bucket class for storing scores, learning curves, and models.""" def __init__(self) -> None: diff --git a/rllte/hub/dmc.py b/rllte/hub/dmc.py index a795cd3b..65048d7a 100644 --- a/rllte/hub/dmc.py +++ b/rllte/hub/dmc.py @@ -23,16 +23,17 @@ # ============================================================================= -from huggingface_hub import hf_hub_download from typing import Dict, Optional -from torch import nn import numpy as np import torch as th -from rllte.hub.bucket import Bucket +from huggingface_hub import hf_hub_download +from torch import nn + from rllte.agent import SAC, DrQv2 -from rllte.env import make_dmc_env from rllte.common.prototype import BaseAgent +from rllte.env import make_dmc_env +from rllte.hub.bucket import Bucket # cheetah_run quadruped_walk quadruped_run walker_walk walker_run hopper_hop arcobot_swingup cup_catch # cartpole_balance cartpole_balance_sparse cartpole_swingup cartpole_swingup_sparse finger_spin finger_turn_easy diff --git a/rllte/hub/minigrid.py b/rllte/hub/minigrid.py index ab31ba53..c72f5cd2 100644 --- a/rllte/hub/minigrid.py +++ b/rllte/hub/minigrid.py @@ -23,16 +23,17 @@ # ============================================================================= -from huggingface_hub import hf_hub_download from typing import Dict, Optional -from torch import nn import numpy as np import torch as th -from rllte.hub.bucket import Bucket +from huggingface_hub import hf_hub_download +from torch import nn + from rllte.agent import A2C, PPO -from rllte.env import make_minigrid_env from rllte.common.prototype import BaseAgent +from rllte.env import make_minigrid_env +from rllte.hub.bucket import Bucket class MiniGrid(Bucket): diff --git a/rllte/hub/procgen.py b/rllte/hub/procgen.py index 47e88fbb..d07677b6 100644 --- a/rllte/hub/procgen.py +++ b/rllte/hub/procgen.py @@ -23,17 +23,18 @@ # ============================================================================= -from huggingface_hub import hf_hub_download from typing import Dict, Optional -from torch import nn import numpy as np import torch as th +from huggingface_hub import hf_hub_download +from torch import nn + +from rllte.agent import DAAC, PPG, PPO +from rllte.common.prototype import BaseAgent +from rllte.env import make_envpool_procgen_env from rllte.hub.bucket import Bucket -from rllte.agent import PPO, PPG, DAAC from rllte.xploit.encoder import EspeholtResidualEncoder -from rllte.env import make_envpool_procgen_env -from rllte.common.prototype import BaseAgent class 
Procgen(Bucket): diff --git a/rllte/xploit/policy/on_policy_shared_actor_critic_lstm.py b/rllte/xploit/policy/on_policy_shared_actor_critic_lstm.py index 0a8c8c91..70d48c6f 100644 --- a/rllte/xploit/policy/on_policy_shared_actor_critic_lstm.py +++ b/rllte/xploit/policy/on_policy_shared_actor_critic_lstm.py @@ -32,6 +32,7 @@ from rllte.common.prototype import BaseDistribution as Distribution from rllte.common.prototype import BasePolicy from rllte.common.utils import ExportModel + from .utils import OnPolicyCritic, get_on_policy_actor # from torch.distributions import Distribution diff --git a/rllte/xploit/policy/utils.py b/rllte/xploit/policy/utils.py index e093cf87..dd00fd5c 100644 --- a/rllte/xploit/policy/utils.py +++ b/rllte/xploit/policy/utils.py @@ -29,7 +29,7 @@ from torch import nn from torch.nn import functional as F -from rllte.common.type_alias import ObsShape, BaseDistribution +from rllte.common.type_alias import BaseDistribution, ObsShape class OnPolicyDiscreteActor(nn.Module): diff --git a/rllte/xploit/storage/__init__.py b/rllte/xploit/storage/__init__.py index c0ffb58c..31fcf4f3 100644 --- a/rllte/xploit/storage/__init__.py +++ b/rllte/xploit/storage/__init__.py @@ -24,10 +24,10 @@ from .dict_replay_storage import DictReplayStorage as DictReplayStorage from .dict_rollout_storage import DictRolloutStorage as DictRolloutStorage +from .episodic_rollout_storage import EpisodicRolloutStorage as EpisodicRolloutStorage from .her_replay_storage import HerReplayStorage as HerReplayStorage from .nstep_replay_storage import NStepReplayStorage as NStepReplayStorage from .prioritized_replay_storage import PrioritizedReplayStorage as PrioritizedReplayStorage from .vanilla_distributed_storage import VanillaDistributedStorage as VanillaDistributedStorage from .vanilla_replay_storage import VanillaReplayStorage as VanillaReplayStorage from .vanilla_rollout_storage import VanillaRolloutStorage as VanillaRolloutStorage -from .episodic_rollout_storage import EpisodicRolloutStorage as EpisodicRolloutStorage \ No newline at end of file diff --git a/rllte/xploit/storage/dict_replay_storage.py b/rllte/xploit/storage/dict_replay_storage.py index fde1f3f0..9ce464ac 100644 --- a/rllte/xploit/storage/dict_replay_storage.py +++ b/rllte/xploit/storage/dict_replay_storage.py @@ -29,8 +29,9 @@ import numpy as np import torch as th -from rllte.xploit.storage.vanilla_replay_storage import VanillaReplayStorage from rllte.common.type_alias import DictReplayBatch +from rllte.xploit.storage.vanilla_replay_storage import VanillaReplayStorage + class DictReplayStorage(VanillaReplayStorage): """Dict replay storage for off-policy algorithms and dictionary observations. 
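The `EpisodicRolloutStorage` diffed just below samples minibatches of whole environments rather than shuffled time steps, so the LSTM policy can be unrolled over intact trajectories during the update. A minimal sketch of that sampling pattern, assuming rollout tensors of shape (num_steps, num_envs, ...); the helper name is hypothetical, not the rllte API:

import torch as th
from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler

def iter_env_minibatches(num_envs: int, num_mini_batch: int):
    # shuffle environment indices, not time steps: each minibatch keeps the
    # full trajectories of its environments contiguous for LSTM unrolling
    envs_per_batch = num_envs // num_mini_batch
    for env_inds in BatchSampler(SubsetRandomSampler(range(num_envs)), envs_per_batch, drop_last=True):
        yield th.as_tensor(env_inds)

# usage with observations of shape (num_steps, num_envs, *obs_shape):
# for inds in iter_env_minibatches(num_envs=8, num_mini_batch=2):
#     mb_obs = observations[:, inds]  # whole trajectories, four envs at a time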
diff --git a/rllte/xploit/storage/episodic_rollout_storage.py b/rllte/xploit/storage/episodic_rollout_storage.py index 53b1e989..18352589 100644 --- a/rllte/xploit/storage/episodic_rollout_storage.py +++ b/rllte/xploit/storage/episodic_rollout_storage.py @@ -26,8 +26,8 @@ from typing import Dict, Generator import gymnasium as gym -import torch as th import numpy as np +import torch as th from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler from rllte.common.prototype import BaseStorage diff --git a/tests/test_reward.py b/tests/test_reward.py index 4951be73..bb828008 100644 --- a/tests/test_reward.py +++ b/tests/test_reward.py @@ -1,11 +1,7 @@ import pytest import torch as th -from rllte.env.testing import (make_box_env, - make_discrete_env, - make_multibinary_env, - make_multidiscrete_env - ) +from rllte.env.testing import make_box_env, make_discrete_env, make_multibinary_env, make_multidiscrete_env from rllte.xplore.reward import GIRM, ICM, NGU, RE3, REVD, RIDE, RISE, RND, PseudoCounts diff --git a/tests/test_storage.py b/tests/test_storage.py index 9cdb6872..d9ee6edf 100644 --- a/tests/test_storage.py +++ b/tests/test_storage.py @@ -1,7 +1,7 @@ import pytest import torch as th -from rllte.env.testing import make_box_env, make_bitflipping_env +from rllte.env.testing import make_bitflipping_env, make_box_env from rllte.xploit.storage import ( DictReplayStorage, DictRolloutStorage, From cf4a330bf9062032f1be7b4efcb95ab5c3de5474 Mon Sep 17 00:00:00 2001 From: roger-creus Date: Thu, 23 Nov 2023 10:44:52 -0500 Subject: [PATCH 5/5] ran make lint --- rllte/agent/__init__.py | 2 +- rllte/agent/daac.py | 4 +- rllte/agent/drac.py | 2 +- rllte/agent/drdaac.py | 2 +- rllte/agent/legacy/a2c.py | 2 +- rllte/agent/legacy/dqn.py | 2 +- rllte/agent/legacy/ppo.py | 2 +- rllte/agent/legacy/sacd.py | 4 +- rllte/agent/ppg.py | 4 +- rllte/agent/ppo_lstm.py | 6 +- rllte/common/prototype/base_distribution.py | 2 +- rllte/common/prototype/off_policy_agent.py | 11 +- rllte/common/prototype/on_policy_agent.py | 14 +- rllte/common/type_alias.py | 3 +- rllte/common/utils.py | 6 +- rllte/env/__init__.py | 2 +- rllte/env/atari/__init__.py | 2 +- rllte/env/atari/wrappers.py | 17 +-- rllte/env/dmc/wrappers.py | 2 +- rllte/env/minigrid/__init__.py | 2 +- rllte/env/testing/box.py | 12 +- rllte/env/testing/discrete.py | 12 +- rllte/env/testing/multibinary.py | 12 +- rllte/env/testing/multidiscrete.py | 10 +- rllte/hub/__init__.py | 2 +- rllte/hub/atari.py | 123 ++++++++++------- rllte/hub/bucket.py | 35 ++--- rllte/hub/dmc.py | 96 +++++++------- rllte/hub/minigrid.py | 49 +++---- rllte/hub/procgen.py | 124 +++++++++--------- rllte/xploit/policy/__init__.py | 2 +- .../off_policy_stoch_actor_double_critic.py | 2 +- .../on_policy_shared_actor_critic_lstm.py | 14 +- rllte/xploit/policy/utils.py | 2 +- .../storage/episodic_rollout_storage.py | 8 +- rllte/xplore/reward/ride.py | 2 +- tests/test_env.py | 14 +- tests/test_reward.py | 8 +- tests/test_storage.py | 8 +- 39 files changed, 319 insertions(+), 307 deletions(-) diff --git a/rllte/agent/__init__.py b/rllte/agent/__init__.py index 09403c61..6ff3d8d0 100644 --- a/rllte/agent/__init__.py +++ b/rllte/agent/__init__.py @@ -36,4 +36,4 @@ from .legacy.sacd import SACDiscrete as SACDiscrete from .legacy.td3 import TD3 as TD3 from .ppg import PPG as PPG -from .ppo_lstm import PPO_LSTM as PPO_LSTM \ No newline at end of file +from .ppo_lstm import PPO_LSTM as PPO_LSTM diff --git a/rllte/agent/daac.py b/rllte/agent/daac.py index 31e384ee..45b4fe41 100644 --- 
a/rllte/agent/daac.py +++ b/rllte/agent/daac.py @@ -95,7 +95,7 @@ def __init__( adv_coef: float = 0.25, max_grad_norm: float = 0.5, discount: float = 0.999, - init_fn: str = "xavier_uniform" + init_fn: str = "xavier_uniform", ) -> None: super().__init__( env=env, @@ -164,7 +164,7 @@ def __init__( storage_size=self.num_steps, num_envs=self.num_envs, batch_size=batch_size, - discount=discount + discount=discount, ) # set all the modules [essential operation!!!] diff --git a/rllte/agent/drac.py b/rllte/agent/drac.py index bb4e4dc8..f675961e 100644 --- a/rllte/agent/drac.py +++ b/rllte/agent/drac.py @@ -157,7 +157,7 @@ def __init__( storage_size=self.num_steps, num_envs=self.num_envs, batch_size=batch_size, - discount=discount + discount=discount, ) # set all the modules [essential operation!!!] diff --git a/rllte/agent/drdaac.py b/rllte/agent/drdaac.py index 70824549..3a9a74f8 100644 --- a/rllte/agent/drdaac.py +++ b/rllte/agent/drdaac.py @@ -170,7 +170,7 @@ def __init__( storage_size=self.num_steps, num_envs=self.num_envs, batch_size=batch_size, - discount=discount + discount=discount, ) # set all the modules [essential operation!!!] diff --git a/rllte/agent/legacy/a2c.py b/rllte/agent/legacy/a2c.py index 05a8b8c1..b556c5c3 100644 --- a/rllte/agent/legacy/a2c.py +++ b/rllte/agent/legacy/a2c.py @@ -139,7 +139,7 @@ def __init__( storage_size=self.num_steps, num_envs=self.num_envs, batch_size=batch_size, - discount=discount + discount=discount, ) # set all the modules [essential operation!!!] diff --git a/rllte/agent/legacy/dqn.py b/rllte/agent/legacy/dqn.py index 73131d02..6326579e 100644 --- a/rllte/agent/legacy/dqn.py +++ b/rllte/agent/legacy/dqn.py @@ -190,4 +190,4 @@ def update(self) -> None: # record metrics self.logger.record("train/q_loss", huber_loss.item()) self.logger.record("train/q", q_values.mean().item()) - self.logger.record("train/target_q", target_q_values.mean().item()) \ No newline at end of file + self.logger.record("train/target_q", target_q_values.mean().item()) diff --git a/rllte/agent/legacy/ppo.py b/rllte/agent/legacy/ppo.py index 2e108167..0a08a393 100644 --- a/rllte/agent/legacy/ppo.py +++ b/rllte/agent/legacy/ppo.py @@ -151,7 +151,7 @@ def __init__( storage_size=self.num_steps, num_envs=self.num_envs, batch_size=batch_size, - discount=discount + discount=discount, ) # set all the modules [essential operation!!!] 
diff --git a/rllte/agent/legacy/sacd.py b/rllte/agent/legacy/sacd.py index 8533ebfe..98bf6c1d 100644 --- a/rllte/agent/legacy/sacd.py +++ b/rllte/agent/legacy/sacd.py @@ -242,7 +242,7 @@ def update_critic( with th.no_grad(): dist = self.policy.get_dist(next_obs) # deal with situation of 0.0 probabilities - action_probs, log_probs = self.deal_with_zero_probs(dist.probs) # type: ignore[attr-defined] + action_probs, log_probs = self.deal_with_zero_probs(dist.probs) # type: ignore[attr-defined] target_Q1, target_Q2 = self.policy.critic_target(next_obs) target_V = (th.min(target_Q1, target_Q2) - self.alpha.detach() * log_probs) * action_probs # TODO: add time limit mask @@ -278,7 +278,7 @@ def update_actor_and_alpha(self, obs: th.Tensor) -> None: """ # sample actions dist = self.policy.get_dist(obs) - action_probs, log_probs = self.deal_with_zero_probs(dist.probs) # type: ignore[attr-defined] + action_probs, log_probs = self.deal_with_zero_probs(dist.probs) # type: ignore[attr-defined] actor_Q1, actor_Q2 = self.policy.critic(obs) actor_Q = th.min(actor_Q1, actor_Q2) diff --git a/rllte/agent/ppg.py b/rllte/agent/ppg.py index 2a22bc95..00bffe19 100644 --- a/rllte/agent/ppg.py +++ b/rllte/agent/ppg.py @@ -97,7 +97,7 @@ def __init__( num_aux_mini_batch: int = 4, num_aux_grad_accum: int = 1, discount: float = 0.999, - init_fn: str = "xavier_uniform" + init_fn: str = "xavier_uniform", ) -> None: super().__init__( env=env, @@ -162,7 +162,7 @@ def __init__( storage_size=self.num_steps, num_envs=self.num_envs, batch_size=batch_size, - discount=discount + discount=discount, ) # set all the modules [essential operation!!!] diff --git a/rllte/agent/ppo_lstm.py b/rllte/agent/ppo_lstm.py index 59a96338..05fa4b59 100644 --- a/rllte/agent/ppo_lstm.py +++ b/rllte/agent/ppo_lstm.py @@ -31,7 +31,7 @@ from rllte.common.prototype import OnPolicyAgent from rllte.common.type_alias import VecEnv -from rllte.xploit.encoder import EspeholtResidualEncoder, IdentityEncoder, MnihCnnEncoder, PathakCnnEncoder +from rllte.xploit.encoder import IdentityEncoder, MnihCnnEncoder from rllte.xploit.policy import OnPolicySharedActorCriticLSTM from rllte.xploit.storage import EpisodicRolloutStorage from rllte.xplore.distribution import Bernoulli, Categorical, DiagonalGaussian, MultiCategorical @@ -163,7 +163,7 @@ def update(self) -> None: total_policy_loss = [0.0] total_value_loss = [0.0] total_entropy_loss = [0.0] - + for _ in range(self.n_epochs): for batch in self.storage.sample(): done = th.logical_or(batch.terminateds, batch.truncateds) @@ -173,7 +173,7 @@ def update(self) -> None: obs=batch.observations, actions=batch.actions, lstm_state=(self.initial_lstm_state[0][:, batch.env_inds], self.initial_lstm_state[1][:, batch.env_inds]), - done=done + done=done, ) # policy loss part diff --git a/rllte/common/prototype/base_distribution.py b/rllte/common/prototype/base_distribution.py index fc1e525e..6564a3a6 100644 --- a/rllte/common/prototype/base_distribution.py +++ b/rllte/common/prototype/base_distribution.py @@ -41,4 +41,4 @@ def __call__(self, *args, **kwargs) -> Any: """Call the distribution.""" def sample(self, *args, **kwargs) -> th.Tensor: # type: ignore - """Generate samples.""" \ No newline at end of file + """Generate samples.""" diff --git a/rllte/common/prototype/off_policy_agent.py b/rllte/common/prototype/off_policy_agent.py index 27bd04f4..a74f7140 100644 --- a/rllte/common/prototype/off_policy_agent.py +++ b/rllte/common/prototype/off_policy_agent.py @@ -73,7 +73,7 @@ def update(self) -> None: """Update the agent. 
Implemented by individual algorithms.""" raise NotImplementedError - def train( # noqa: C901 + def train( # noqa: C901 self, num_train_steps: int, init_model_path: Optional[str] = None, @@ -82,7 +82,7 @@ def train( # noqa: C901 save_interval: int = 5000, num_eval_episodes: int = 10, th_compile: bool = False, - anneal_lr: bool = False + anneal_lr: bool = False, ) -> None: """Training function. @@ -123,7 +123,7 @@ def train( # noqa: C901 actions = th.stack([th.as_tensor(self.action_space.sample()) for _ in range(self.num_envs)]) else: actions = self.policy(obs, training=True) - + # update the learning rate if anneal_lr: for key in self.policy.optimizers.keys(): @@ -151,10 +151,7 @@ def train( # noqa: C901 for idx, (term, trunc) in enumerate(zip(terms, truncs)): if term.item() or trunc.item(): # TODO: deal with dict observations - try: - real_next_obs[idx] = th.as_tensor(infos["final_observation"][idx], device=self.device) # type: ignore[index] - except: - pass + real_next_obs[idx] = th.as_tensor(infos["final_observation"][idx], device=self.device) # type: ignore[index] # add new transitions self.storage.add(obs, actions.unsqueeze(-1), rews, terms, truncs, infos, real_next_obs) diff --git a/rllte/common/prototype/on_policy_agent.py b/rllte/common/prototype/on_policy_agent.py index b98f6a97..71e721f5 100644 --- a/rllte/common/prototype/on_policy_agent.py +++ b/rllte/common/prototype/on_policy_agent.py @@ -81,7 +81,7 @@ def train( save_interval: int = 100, num_eval_episodes: int = 10, th_compile: bool = True, - anneal_lr: bool = False + anneal_lr: bool = False, ) -> None: """Training function. @@ -126,7 +126,7 @@ def train( eval_metrics = self.eval(num_eval_episodes) # log to console self.logger.eval(msg=eval_metrics) - + # update the learning rate if anneal_lr: for key in self.policy.optimizers.keys(): @@ -141,10 +141,10 @@ def train( del extra_policy_outputs["lstm_state"] else: actions, extra_policy_outputs = self.policy(obs, training=True) - + # observe rewards and next obs next_obs, rews, terms, truncs, infos = self.env.step(actions) - + if self.use_lstm: done = th.logical_or(terms, truncs) @@ -264,7 +264,7 @@ def eval(self, num_eval_episodes: int) -> Dict[str, Any]: if self.use_lstm: actions, extra_policy_outputs = self.policy(obs, lstm_state, done, training=False) lstm_state = extra_policy_outputs["lstm_state"] - del extra_policy_outputs["lstm_state"] + del extra_policy_outputs["lstm_state"] else: actions, _ = self.policy(obs, training=False) @@ -278,14 +278,14 @@ def eval(self, num_eval_episodes: int) -> Dict[str, Any]: eps_r, eps_l = utils.get_episode_statistics(infos) episode_rewards.extend(eps_r) episode_steps.extend(eps_l) - + if self.use_lstm: lstm_state = ( th.zeros(self.policy.lstm.num_layers, self.num_envs, self.policy.lstm.hidden_size).to(self.device), th.zeros(self.policy.lstm.num_layers, self.num_envs, self.policy.lstm.hidden_size).to(self.device), ) done = th.zeros(self.num_envs, dtype=th.bool, device=self.device) - + # set the current observation obs = next_obs diff --git a/rllte/common/type_alias.py b/rllte/common/type_alias.py index c6bd3932..d4114fc6 100644 --- a/rllte/common/type_alias.py +++ b/rllte/common/type_alias.py @@ -137,7 +137,8 @@ class VanillaRolloutBatch(NamedTuple): truncateds: th.Tensor old_log_probs: th.Tensor adv_targ: th.Tensor - + + class EpisodicRolloutBatch(NamedTuple): observations: th.Tensor actions: th.Tensor diff --git a/rllte/common/utils.py b/rllte/common/utils.py index 5c1bf90f..8c65c01e 100644 --- a/rllte/common/utils.py +++ 
b/rllte/common/utils.py @@ -130,7 +130,7 @@ def get_episode_statistics(infos: Dict) -> Tuple[List, List]: r: List = [] l: List = [] # to handle the Atari environments - for info in infos['final_info']: + for info in infos["final_info"]: if info is not None and "episode" in info.keys(): r.extend(info["episode"]["r"].tolist()) l.extend(info["episode"]["l"].tolist()) @@ -191,10 +191,10 @@ def linear_lr_scheduler(optimizer, steps, total_num_steps, initial_lr) -> None: steps (int): Current step. total_num_steps (int): Total number of steps. initial_lr (float): Initial learning rate. - + Returns: None. """ lr = initial_lr - (initial_lr * (steps / float(total_num_steps))) for param_group in optimizer.param_groups: - param_group['lr'] = lr \ No newline at end of file + param_group["lr"] = lr diff --git a/rllte/env/__init__.py b/rllte/env/__init__.py index 588ebee7..f9eb800b 100644 --- a/rllte/env/__init__.py +++ b/rllte/env/__init__.py @@ -55,4 +55,4 @@ from .procgen import make_envpool_procgen_env as make_envpool_procgen_env from .procgen import make_procgen_env as make_procgen_env except Exception: - pass \ No newline at end of file + pass diff --git a/rllte/env/atari/__init__.py b/rllte/env/atari/__init__.py index 65c7df81..ab558c00 100644 --- a/rllte/env/atari/__init__.py +++ b/rllte/env/atari/__init__.py @@ -63,7 +63,7 @@ def make_envpool_atari_env( batch_size=num_envs, seed=seed, episodic_life=True, - reward_clip=True + reward_clip=True, ) if asynchronous: diff --git a/rllte/env/atari/wrappers.py b/rllte/env/atari/wrappers.py index 231c0080..f6557f3a 100644 --- a/rllte/env/atari/wrappers.py +++ b/rllte/env/atari/wrappers.py @@ -194,22 +194,23 @@ class RecordEpisodeStatistics4EnvPool(gym.Wrapper): - """Keep track of cumulative rewards and episode lengths. + """Keep track of cumulative rewards and episode lengths. This wrapper is dedicated to EnvPool-based Atari games. Args: env (gym.Env): Environment to wrap. deque_size (int): The size of the buffers :attr:`return_queue` and :attr:`length_queue` - + Returns: RecordEpisodeStatistics4EnvPool instance.
""" + def __init__(self, env: gym.Env, deque_size: int = 100) -> None: super().__init__(env) self.num_envs = getattr(env, "num_envs", 1) self.episode_returns: Optional[np.ndarray] = None self.episode_lengths: Optional[np.ndarray] = None - + def reset(self, **kwargs): observations, infos = super().reset(**kwargs) self.episode_returns = np.zeros(self.num_envs, dtype=np.float32) @@ -217,7 +218,7 @@ def reset(self, **kwargs): self.returned_episode_returns = np.zeros(self.num_envs, dtype=np.float32) self.returned_episode_lengths = np.zeros(self.num_envs, dtype=np.int32) return observations, infos - + def step(self, actions): observations, rewards, terms, truncs, infos = super().step(actions) self.episode_returns += infos["reward"] @@ -231,8 +232,8 @@ def step(self, actions): infos["episode"]["l"] = self.returned_episode_lengths for idx, d in enumerate(terms): - if not d or infos["lives"][idx] != 0: - infos["episode"]["r"][idx] = 0 - infos["episode"]["l"][idx] = 0 + if not d or infos["lives"][idx] != 0: + infos["episode"]["r"][idx] = 0 + infos["episode"]["l"][idx] = 0 - return observations, rewards, terms, truncs, infos \ No newline at end of file + return observations, rewards, terms, truncs, infos diff --git a/rllte/env/dmc/wrappers.py b/rllte/env/dmc/wrappers.py index 5b784176..05948e42 100644 --- a/rllte/env/dmc/wrappers.py +++ b/rllte/env/dmc/wrappers.py @@ -227,7 +227,7 @@ def _extract_min_max(s): mins, maxs = [], [] for s in spec: - mn, mx = _extract_min_max(s) # type: ignore + mn, mx = _extract_min_max(s) # type: ignore mins.append(mn) maxs.append(mx) low = np.concatenate(mins, axis=0).astype(dtype) diff --git a/rllte/env/minigrid/__init__.py b/rllte/env/minigrid/__init__.py index d1dec9a6..bb14c2a1 100644 --- a/rllte/env/minigrid/__init__.py +++ b/rllte/env/minigrid/__init__.py @@ -145,4 +145,4 @@ def _thunk(): envs = SyncVectorEnv(envs) envs = RecordEpisodeStatistics(envs) - return Gymnasium2Torch(envs, device=device) \ No newline at end of file + return Gymnasium2Torch(envs, device=device) diff --git a/rllte/env/testing/box.py b/rllte/env/testing/box.py index dd33f40b..61ac5cec 100644 --- a/rllte/env/testing/box.py +++ b/rllte/env/testing/box.py @@ -118,15 +118,18 @@ def step(self, action: Any) -> Tuple[Any, SupportsFloat, bool, bool, Dict[str, A return obs, reward, terminated, truncated, info + class DictEnv(gym.Env): """Environment with dict-based observation space and `Box` action space for testing.""" def __init__(self) -> None: super().__init__() - self.observation_space = gym.spaces.Dict(spaces={ - "image": gym.spaces.Box(low=-1.0, high=1.0, shape=(3, 84, 84), dtype=np.float32), - "state": gym.spaces.Box(low=-1.0, high=1.0, shape=(49,), dtype=np.float32), - }) + self.observation_space = gym.spaces.Dict( + spaces={ + "image": gym.spaces.Box(low=-1.0, high=1.0, shape=(3, 84, 84), dtype=np.float32), + "state": gym.spaces.Box(low=-1.0, high=1.0, shape=(49,), dtype=np.float32), + } + ) self.action_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(7,), dtype=np.float32) def reset(self, seed: Optional[int] = None, options=Optional[Dict[str, Any]]) -> Tuple[Any, Dict[str, Any]]: @@ -161,6 +164,7 @@ def step(self, action: Any) -> Tuple[Any, SupportsFloat, bool, bool, Dict[str, A return obs, reward, terminated, truncated, info + def make_box_env( env_id: str = "StateObsEnv", num_envs: int = 1, device: str = "cpu", seed: int = 0, asynchronous: bool = True ) -> Gymnasium2Torch: diff --git a/rllte/env/testing/discrete.py b/rllte/env/testing/discrete.py index d890817a..202ff482 100644 --- 
a/rllte/env/testing/discrete.py +++ b/rllte/env/testing/discrete.py @@ -114,15 +114,18 @@ def step(self, action: Any) -> Tuple[Any, SupportsFloat, bool, bool, Dict[str, A return obs, reward, terminated, truncated, info + class DictEnv(gym.Env): """Environment with dict-based observation space and `Discrete` action space for testing.""" def __init__(self) -> None: super().__init__() - self.observation_space = gym.spaces.Dict(spaces={ - "image": gym.spaces.Box(low=-1.0, high=1.0, shape=(3, 84, 84), dtype=np.float32), - "state": gym.spaces.Box(low=-1.0, high=1.0, shape=(49,), dtype=np.float32), - }) + self.observation_space = gym.spaces.Dict( + spaces={ + "image": gym.spaces.Box(low=-1.0, high=1.0, shape=(3, 84, 84), dtype=np.float32), + "state": gym.spaces.Box(low=-1.0, high=1.0, shape=(49,), dtype=np.float32), + } + ) self.action_space = gym.spaces.Discrete(n=7) def reset(self, seed: Optional[int] = None, options=Optional[Dict[str, Any]]) -> Tuple[Any, Dict[str, Any]]: @@ -157,6 +160,7 @@ def step(self, action: Any) -> Tuple[Any, SupportsFloat, bool, bool, Dict[str, A return obs, reward, terminated, truncated, info + def make_discrete_env( env_id: str = "StateObsEnv", num_envs: int = 1, device: str = "cpu", seed: int = 0, asynchronous: bool = True ) -> Gymnasium2Torch: diff --git a/rllte/env/testing/multibinary.py b/rllte/env/testing/multibinary.py index d7e93ee3..173e98eb 100644 --- a/rllte/env/testing/multibinary.py +++ b/rllte/env/testing/multibinary.py @@ -114,15 +114,18 @@ def step(self, action: Any) -> Tuple[Any, SupportsFloat, bool, bool, Dict[str, A return obs, reward, terminated, truncated, info + class DictEnv(gym.Env): """Environment with dict-based observation space and `MultiBinary` action space for testing.""" def __init__(self) -> None: super().__init__() - self.observation_space = gym.spaces.Dict(spaces={ - "image": gym.spaces.Box(low=-1.0, high=1.0, shape=(3, 84, 84), dtype=np.float32), - "state": gym.spaces.Box(low=-1.0, high=1.0, shape=(49,), dtype=np.float32), - }) + self.observation_space = gym.spaces.Dict( + spaces={ + "image": gym.spaces.Box(low=-1.0, high=1.0, shape=(3, 84, 84), dtype=np.float32), + "state": gym.spaces.Box(low=-1.0, high=1.0, shape=(49,), dtype=np.float32), + } + ) self.action_space = gym.spaces.MultiBinary(n=3) def reset(self, seed: Optional[int] = None, options=Optional[Dict[str, Any]]) -> Tuple[Any, Dict[str, Any]]: @@ -157,6 +160,7 @@ def step(self, action: Any) -> Tuple[Any, SupportsFloat, bool, bool, Dict[str, A return obs, reward, terminated, truncated, info + def make_multibinary_env( env_id: str = "StateObsEnv", num_envs: int = 1, device: str = "cpu", seed: int = 0, asynchronous: bool = True ) -> Gymnasium2Torch: diff --git a/rllte/env/testing/multidiscrete.py b/rllte/env/testing/multidiscrete.py index b9c1b04a..f27283b6 100644 --- a/rllte/env/testing/multidiscrete.py +++ b/rllte/env/testing/multidiscrete.py @@ -120,10 +120,12 @@ class DictEnv(gym.Env): def __init__(self) -> None: super().__init__() - self.observation_space = gym.spaces.Dict(spaces={ - "image": gym.spaces.Box(low=-1.0, high=1.0, shape=(3, 84, 84), dtype=np.float32), - "state": gym.spaces.Box(low=-1.0, high=1.0, shape=(49,), dtype=np.float32), - }) + self.observation_space = gym.spaces.Dict( + spaces={ + "image": gym.spaces.Box(low=-1.0, high=1.0, shape=(3, 84, 84), dtype=np.float32), + "state": gym.spaces.Box(low=-1.0, high=1.0, shape=(49,), dtype=np.float32), + } + ) self.action_space = gym.spaces.MultiDiscrete(nvec=(2, 3, 4)) def reset(self, seed: Optional[int] = None, 
options=Optional[Dict[str, Any]]) -> Tuple[Any, Dict[str, Any]]: diff --git a/rllte/hub/__init__.py b/rllte/hub/__init__.py index eb5d97df..0b8233f3 100644 --- a/rllte/hub/__init__.py +++ b/rllte/hub/__init__.py @@ -26,4 +26,4 @@ from .atari import Atari as Atari from .dmc import DMControl as DMControl from .minigrid import MiniGrid as MiniGrid -from .procgen import Procgen as Procgen \ No newline at end of file +from .procgen import Procgen as Procgen diff --git a/rllte/hub/atari.py b/rllte/hub/atari.py index c51b7de4..ef3f3c0b 100644 --- a/rllte/hub/atari.py +++ b/rllte/hub/atari.py @@ -23,7 +23,7 @@ # ============================================================================= -from typing import Callable, Dict +from typing import Dict import numpy as np import torch as th @@ -44,39 +44,87 @@ class Atari(Bucket): Number of seeds: 10 Added algorithms: [PPO] """ + def __init__(self) -> None: super().__init__() - self.sup_env = ['Alien-v5', 'Amidar-v5', 'Assault-v5', 'Asterix-v5', 'Asteroids-v5', 'Atlantis-v5', 'YarsRevenge-v5', - 'BankHeist-v5', 'BattleZone-v5', 'BeamRider-v5', 'Berzerk-v5', 'Bowling-v5', 'Boxing-v5', 'Breakout-v5', - 'Centipede-v5', 'ChopperCommand-v5', 'CrazyClimber-v5', 'Defender-v5', 'DemonAttack-v5', 'DoubleDunk-v5', 'Zaxxon-v5', - 'Enduro-v5', 'FishingDerby-v5', 'Freeway-v5', 'Frostbite-v5', 'Gopher-v5', 'Gravitar-v5', 'Hero-v5', - 'IceHockey-v5', 'Jamesbond-v5', 'Kangaroo-v5', 'Krull-v5', 'KungFuMaster-v5', 'MontezumaRevenge-v5', 'Pitfall-v5', - 'PrivateEye-v5', 'Qbert-v5', 'Riverraid-v5', 'RoadRunner-v5', 'Robotank-v5', 'Seaquest-v5', 'Phoenix-v5', 'Pong-v5', - 'Skiing-v5', 'Solaris-v5', 'SpaceInvaders-v5', 'StarGunner-v5', 'Surround-v5', 'Tennis-v5', 'TimePilot-v5', - 'Tutankham-v5', 'UpNDown-v5', 'Venture-v5', 'VideoPinball-v5', 'WizardOfWor-v5', 'MsPacman-v5', 'NameThisGame-v5' - ] - self.sup_algo = ['ppo'] + self.sup_env = [ + "Alien-v5", + "Amidar-v5", + "Assault-v5", + "Asterix-v5", + "Asteroids-v5", + "Atlantis-v5", + "YarsRevenge-v5", + "BankHeist-v5", + "BattleZone-v5", + "BeamRider-v5", + "Berzerk-v5", + "Bowling-v5", + "Boxing-v5", + "Breakout-v5", + "Centipede-v5", + "ChopperCommand-v5", + "CrazyClimber-v5", + "Defender-v5", + "DemonAttack-v5", + "DoubleDunk-v5", + "Zaxxon-v5", + "Enduro-v5", + "FishingDerby-v5", + "Freeway-v5", + "Frostbite-v5", + "Gopher-v5", + "Gravitar-v5", + "Hero-v5", + "IceHockey-v5", + "Jamesbond-v5", + "Kangaroo-v5", + "Krull-v5", + "KungFuMaster-v5", + "MontezumaRevenge-v5", + "Pitfall-v5", + "PrivateEye-v5", + "Qbert-v5", + "Riverraid-v5", + "RoadRunner-v5", + "Robotank-v5", + "Seaquest-v5", + "Phoenix-v5", + "Pong-v5", + "Skiing-v5", + "Solaris-v5", + "SpaceInvaders-v5", + "StarGunner-v5", + "Surround-v5", + "Tennis-v5", + "TimePilot-v5", + "Tutankham-v5", + "UpNDown-v5", + "Venture-v5", + "VideoPinball-v5", + "WizardOfWor-v5", + "MsPacman-v5", + "NameThisGame-v5", + ] + self.sup_algo = ["ppo"] def load_scores(self, env_id: str, agent: str) -> np.ndarray: """Returns final performance. - + Args: env_id (str): Environment ID. agent_id (str): Agent name. - + Returns: Test scores data array with shape (N_SEEDS, N_POINTS). 
""" self.is_available(env_id=env_id, agent=agent.lower()) - scores_file = f'{agent.lower()}_atari_{env_id}_scores.npy' + scores_file = f"{agent.lower()}_atari_{env_id}_scores.npy" file = hf_hub_download( - repo_id="RLE-Foundation/rllte-hub", - repo_type="model", - filename=scores_file, - subfolder="atari/scores" + repo_id="RLE-Foundation/rllte-hub", repo_type="model", filename=scores_file, subfolder="atari/scores" ) return np.load(file) @@ -87,7 +135,7 @@ def load_curves(self, env_id: str, agent: str) -> Dict[str, np.ndarray]: Args: env_id (str): Environment ID. agent_id (str): Agent name. - + Returns: Learning curves data with structure: curves @@ -96,26 +144,18 @@ def load_curves(self, env_id: str, agent: str) -> Dict[str, np.ndarray]: """ self.is_available(env_id=env_id, agent=agent.lower()) - curves_file = f'{agent.lower()}_atari_{env_id}_curves.npz' + curves_file = f"{agent.lower()}_atari_{env_id}_curves.npz" file = hf_hub_download( - repo_id="RLE-Foundation/rllte-hub", - repo_type="model", - filename=curves_file, - subfolder="atari/curves" + repo_id="RLE-Foundation/rllte-hub", repo_type="model", filename=curves_file, subfolder="atari/curves" ) curves_dict = np.load(file, allow_pickle=True) curves_dict = dict(curves_dict) return curves_dict - - def load_models(self, - env_id: str, - agent: str, - seed: int, - device: str = "cpu" - ) -> nn.Module: + + def load_models(self, env_id: str, agent: str, seed: int, device: str = "cpu") -> nn.Module: """Load the model from the hub. Args: @@ -128,20 +168,15 @@ def load_models(self, The loaded model. """ self.is_available(env_id=env_id, agent=agent.lower()) - + model_file = f"{agent.lower()}_atari_{env_id}_seed_{seed}.pth" subfolder = f"atari/{agent}" file = hf_hub_download(repo_id="RLE-Foundation/rllte-hub", repo_type="model", filename=model_file, subfolder=subfolder) model = th.load(file, map_location=device) return model.eval() - - def load_apis(self, - env_id: str, - agent: str, - seed: int, - device: str = "cpu" - ) -> BaseAgent: + + def load_apis(self, env_id: str, agent: str, seed: int, device: str = "cpu") -> BaseAgent: """Load the a training API. Args: @@ -156,7 +191,7 @@ def load_apis(self, if agent.lower() == "ppo": # The following hyperparameters are from the repository: # https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail - # Since the asynchronous mode achieved much lower training performance than the synchronous mode, + # Since the asynchronous mode achieved much lower training performance than the synchronous mode, # we recommend using the synchronous mode currently. 
envs = make_envpool_atari_env(env_id=env_id, num_envs=8, device=device, seed=seed, asynchronous=False) eval_envs = make_envpool_atari_env(env_id=env_id, num_envs=8, device=device, seed=seed, asynchronous=False) @@ -179,7 +214,7 @@ def load_apis(self, ent_coef=0.01, max_grad_norm=0.5, discount=0.99, - init_fn="orthogonal" + init_fn="orthogonal", ) elif agent.lower() == "a2c": # The following hyperparameters are from the repository: @@ -187,7 +222,7 @@ def load_apis(self, envs = make_envpool_atari_env(env_id=env_id, num_envs=16, device=device, seed=seed) eval_envs = make_envpool_atari_env(env_id=env_id, num_envs=16, device=device, seed=seed, asynchronous=False) - api = A2C( # type: ignore[assignment] + api = A2C( # type: ignore[assignment] env=envs, eval_env=eval_envs, tag=f"a2c_atari_{env_id}_seed_{seed}", @@ -210,7 +245,7 @@ def load_apis(self, # https://github.com/facebookresearch/torchbeast envs = make_atari_env(env_id=env_id, device=device, seed=seed, num_envs=45, asynchronous=False) eval_envs = make_atari_env(env_id=env_id, device=device, seed=seed, num_envs=1, asynchronous=False) - self.agent = IMPALA( # type: ignore[assignment] + self.agent = IMPALA( # type: ignore[assignment] env=envs, eval_env=eval_envs, tag=f"impala_atari_{env_id}_seed_{seed}", @@ -225,4 +260,4 @@ def load_apis(self, else: raise NotImplementedError(f"Agent {agent} is not supported currently, available agents are: [A2C, PPO, IMPALA].") - return api \ No newline at end of file + return api diff --git a/rllte/hub/bucket.py b/rllte/hub/bucket.py index e8b54570..7d00cf07 100644 --- a/rllte/hub/bucket.py +++ b/rllte/hub/bucket.py @@ -24,7 +24,7 @@ from abc import ABC, abstractmethod -from typing import Callable, Dict, List, Optional +from typing import Dict, List, Optional import numpy as np from torch import nn @@ -34,31 +34,30 @@ class Bucket(ABC): """Bucket class for storing scores, learning curves, and models.""" + def __init__(self) -> None: super().__init__() self.sup_env: List = [] self.sup_algo: List = [] - def is_available(self, env_id: str, agent: Optional[str] = None) -> None: """Check if the dataset is available.""" - assert env_id in self.sup_env and agent in self.sup_algo, \ - f"Datasets for `{env_id}` and `{agent}` are not available currently!" - + assert ( + env_id in self.sup_env and agent in self.sup_algo + ), f"Datasets for `{env_id}` and `{agent}` are not available currently!" @abstractmethod def load_scores(self, env_id: str, agent: str) -> np.ndarray: """Returns final performance. - + Args: env_id (str): Environment ID. agent_id (str): Agent name. - + Returns: Test scores data array with shape (N_SEEDS, N_POINTS). """ - @abstractmethod def load_curves(self, env_id: str, agent: str) -> Dict[str, np.ndarray]: @@ -67,21 +66,16 @@ def load_curves(self, env_id: str, agent: str) -> Dict[str, np.ndarray]: Args: env_id (str): Environment ID. agent_id (str): Agent name. - + Returns: Learning curves data with structure: curves ├── train: np.ndarray(shape=(N_SEEDS, N_POINTS)) └── eval: np.ndarray(shape=(N_SEEDS, N_POINTS)) """ - + @abstractmethod - def load_models(self, - env_id: str, - agent: str, - seed: int, - device: str = "cpu" - ) -> nn.Module: + def load_models(self, env_id: str, agent: str, seed: int, device: str = "cpu") -> nn.Module: """Load the model from the hub. 
Args: @@ -95,12 +89,7 @@ def load_models(self, """ @abstractmethod - def load_apis(self, - env_id: str, - agent: str, - seed: int, - device: str = "cpu" - ) -> BaseAgent: + def load_apis(self, env_id: str, agent: str, seed: int, device: str = "cpu") -> BaseAgent: """Load the a training API. Args: @@ -111,4 +100,4 @@ def load_apis(self, Returns: The loaded API. - """ \ No newline at end of file + """ diff --git a/rllte/hub/dmc.py b/rllte/hub/dmc.py index 65048d7a..94aaf12f 100644 --- a/rllte/hub/dmc.py +++ b/rllte/hub/dmc.py @@ -35,10 +35,11 @@ from rllte.env import make_dmc_env from rllte.hub.bucket import Bucket -# cheetah_run quadruped_walk quadruped_run walker_walk walker_run hopper_hop arcobot_swingup cup_catch +# cheetah_run quadruped_walk quadruped_run walker_walk walker_run hopper_hop arcobot_swingup cup_catch # cartpole_balance cartpole_balance_sparse cartpole_swingup cartpole_swingup_sparse finger_spin finger_turn_easy # finger_turn_hard fish_swim fish_upright hopper_stand pendulum_swingup quadruped_run reacher_easy reacher_hard swimmer_swimmer6 swimmer_swimmer15 + class DMControl(Bucket): """Scores and learning cures of various RL algorithms on the full DeepMind Control Suite benchmark. @@ -48,53 +49,70 @@ class DMControl(Bucket): Number of seeds: 10 Added algorithms: [SAC, DrQ-v2] """ + def __init__(self) -> None: super().__init__() - self.sup_env = ['acrobot_swingup', 'cartpole_balance', 'cartpole_balance_sparse', - 'cartpole_swingup', 'cartpole_swingup_sparse', 'cheetah_run', - 'cup_catch', 'finger_spin', 'finger_turn_easy', - 'finger_turn_hard', 'fish_swim', 'fish_upright', - 'hopper_hop', 'hopper_stand', 'pendulum_swingup', - 'quadruped_run', 'quadruped_walk', 'reacher_easy', - 'reacher_hard', 'swimmer_swimmer6', 'swimmer_swimmer15', - 'walker_run', 'walker_walk', 'walker_stand', - 'humanoid_walk', 'humanoid_run', 'humanoid_stand' - ] - self.sup_algo = ['sac'] - + self.sup_env = [ + "acrobot_swingup", + "cartpole_balance", + "cartpole_balance_sparse", + "cartpole_swingup", + "cartpole_swingup_sparse", + "cheetah_run", + "cup_catch", + "finger_spin", + "finger_turn_easy", + "finger_turn_hard", + "fish_swim", + "fish_upright", + "hopper_hop", + "hopper_stand", + "pendulum_swingup", + "quadruped_run", + "quadruped_walk", + "reacher_easy", + "reacher_hard", + "swimmer_swimmer6", + "swimmer_swimmer15", + "walker_run", + "walker_walk", + "walker_stand", + "humanoid_walk", + "humanoid_run", + "humanoid_stand", + ] + self.sup_algo = ["sac"] + def get_obs_type(self, agent: str) -> str: """Returns the observation type of the agent. - + Args: agent (str): Agent name. - + Returns: Observation type. """ - obs_type = 'state' if agent in ['sac'] else 'pixel' + obs_type = "state" if agent in ["sac"] else "pixel" return obs_type def load_scores(self, env_id: str, agent: str) -> Dict[str, np.ndarray]: """Returns final performance. - + Args: env_id (str): Environment ID. agent_id (str): Agent name. - + Returns: Test scores data array with shape (N_SEEDS, N_POINTS). 
""" self.is_available(env_id=env_id, agent=agent.lower()) obs_type = self.get_obs_type(agent=agent.lower()) - scores_file = f'{agent.lower()}_dmc_{obs_type}_{env_id}_scores.npy' + scores_file = f"{agent.lower()}_dmc_{obs_type}_{env_id}_scores.npy" file = hf_hub_download( - repo_id="RLE-Foundation/rllte-hub", - repo_type="model", - filename=scores_file, - subfolder="dmc/scores" + repo_id="RLE-Foundation/rllte-hub", repo_type="model", filename=scores_file, subfolder="dmc/scores" ) return np.load(file) @@ -106,7 +124,7 @@ def load_curves(self, env_id: str, agent: str) -> Dict[str, np.ndarray]: env_id (str): Environment ID. agent_id (str): Agent name. obs_type (str): A type from ['state', 'pixel']. - + Returns: Learning curves data with structure: curves @@ -116,13 +134,10 @@ def load_curves(self, env_id: str, agent: str) -> Dict[str, np.ndarray]: self.is_available(env_id=env_id, agent=agent.lower()) obs_type = self.get_obs_type(agent=agent.lower()) - curves_file = f'{agent.lower()}_dmc_{obs_type}_{env_id}_curves.npz' + curves_file = f"{agent.lower()}_dmc_{obs_type}_{env_id}_curves.npz" file = hf_hub_download( - repo_id="RLE-Foundation/rllte-hub", - repo_type="model", - filename=curves_file, - subfolder="dmc/curves" + repo_id="RLE-Foundation/rllte-hub", repo_type="model", filename=curves_file, subfolder="dmc/curves" ) curves_dict = np.load(file, allow_pickle=True) @@ -130,12 +145,7 @@ def load_curves(self, env_id: str, agent: str) -> Dict[str, np.ndarray]: return curves_dict - def load_models(self, - env_id: str, - agent: str, - seed: int, - device: str = "cpu" - ) -> nn.Module: + def load_models(self, env_id: str, agent: str, seed: int, device: str = "cpu") -> nn.Module: """Load the model from the hub. Args: @@ -157,13 +167,7 @@ def load_models(self, return model.eval() - - def load_apis(self, - env_id: str, - agent: str, - seed: int, - device: str = "cpu" - ) -> BaseAgent: + def load_apis(self, env_id: str, agent: str, seed: int, device: str = "cpu") -> BaseAgent: """Load the a training API. Args: @@ -213,7 +217,7 @@ def load_apis(self, visualize_reward=False, frame_stack=3, action_repeat=2, - asynchronous=False + asynchronous=False, ) eval_envs = make_dmc_env( env_id=env_id, @@ -224,10 +228,10 @@ def load_apis(self, visualize_reward=False, frame_stack=3, action_repeat=2, - asynchronous=False + asynchronous=False, ) # create agent - api = DrQv2( # type: ignore[assignment] + api = DrQv2( # type: ignore[assignment] env=envs, eval_env=eval_envs, tag=f"drqv2_dmc_pixel_{env_id}_seed_{seed}", @@ -247,4 +251,4 @@ def load_apis(self, f"Agent {agent} is not supported currently, available agents are: [SAC, DDPG, TD3, DrQv2]." ) - return api \ No newline at end of file + return api diff --git a/rllte/hub/minigrid.py b/rllte/hub/minigrid.py index c72f5cd2..4cd7e12a 100644 --- a/rllte/hub/minigrid.py +++ b/rllte/hub/minigrid.py @@ -44,31 +44,29 @@ class MiniGrid(Bucket): Number of seeds: 10 Added algorithms: [A2C] """ + def __init__(self) -> None: super().__init__() - self.sup_env = ['Empty-6x6-v0'] - self.sup_algo = ['ppo'] + self.sup_env = ["Empty-6x6-v0"] + self.sup_algo = ["ppo"] def load_scores(self, env_id: str, agent: str) -> np.ndarray: """Returns final performance. - + Args: env_id (str): Environment ID. agent_id (str): Agent name. - + Returns: Test scores data array with shape (N_SEEDS, N_POINTS). 
""" self.is_available(env_id=env_id, agent=agent.lower()) - scores_file = f'{agent.lower()}_minigrid_{env_id}_scores.npy' + scores_file = f"{agent.lower()}_minigrid_{env_id}_scores.npy" file = hf_hub_download( - repo_id="RLE-Foundation/rllte-hub", - repo_type="model", - filename=scores_file, - subfolder="minigrid/scores" + repo_id="RLE-Foundation/rllte-hub", repo_type="model", filename=scores_file, subfolder="minigrid/scores" ) return np.load(file) @@ -79,7 +77,7 @@ def load_curves(self, env_id: str, agent: str) -> Dict[str, np.ndarray]: Args: env_id (str): Environment ID. agent_id (str): Agent name. - + Returns: Learning curves data with structure: curves @@ -88,13 +86,10 @@ def load_curves(self, env_id: str, agent: str) -> Dict[str, np.ndarray]: """ self.is_available(env_id=env_id, agent=agent.lower()) - curves_file = f'{agent.lower()}_minigrid_{env_id}_curves.npz' + curves_file = f"{agent.lower()}_minigrid_{env_id}_curves.npz" file = hf_hub_download( - repo_id="RLE-Foundation/rllte-hub", - repo_type="model", - filename=curves_file, - subfolder="minigrid/curves" + repo_id="RLE-Foundation/rllte-hub", repo_type="model", filename=curves_file, subfolder="minigrid/curves" ) curves_dict = np.load(file, allow_pickle=True) @@ -102,13 +97,7 @@ def load_curves(self, env_id: str, agent: str) -> Dict[str, np.ndarray]: return curves_dict - - def load_models(self, - env_id: str, - agent: str, - seed: int, - device: str = "cpu" - ) -> nn.Module: + def load_models(self, env_id: str, agent: str, seed: int, device: str = "cpu") -> nn.Module: """Load the model from the hub. Args: @@ -121,21 +110,15 @@ def load_models(self, The loaded model. """ self.is_available(env_id=env_id, agent=agent.lower()) - + model_file = f"{agent.lower()}_minigrid_{env_id}_seed_{seed}.pth" subfolder = f"minigrid/{agent}" file = hf_hub_download(repo_id="RLE-Foundation/rllte-hub", repo_type="model", filename=model_file, subfolder=subfolder) model = th.load(file, map_location=device) return model.eval() - - - def load_apis(self, - env_id: str, - agent: str, - seed: int, - device: str = "cpu" - ) -> BaseAgent: + + def load_apis(self, env_id: str, agent: str, seed: int, device: str = "cpu") -> BaseAgent: """Load the a training API. 
Args: @@ -178,7 +161,7 @@ def load_apis(self, # https://github.com/lcswillems/rl-starter-files envs = make_minigrid_env(env_id=env_id, num_envs=1, device=device, seed=seed) eval_envs = make_minigrid_env(env_id=env_id, num_envs=1, device=device, seed=seed) - api = A2C( # type: ignore[assignment] + api = A2C( # type: ignore[assignment] env=envs, eval_env=eval_envs, tag=f"a2c_{env_id}_seed_{seed}", @@ -199,4 +182,4 @@ def load_apis(self, else: raise NotImplementedError(f"Agent {agent} is not supported currently, available agents are: [A2C, PPO].") - return api \ No newline at end of file + return api diff --git a/rllte/hub/procgen.py b/rllte/hub/procgen.py index d07677b6..8dd16ea6 100644 --- a/rllte/hub/procgen.py +++ b/rllte/hub/procgen.py @@ -45,36 +45,46 @@ class Procgen(Bucket): Number of seeds: 10 Added algorithms: [PPO] """ + def __init__(self) -> None: super().__init__() - self.sup_env = ['bigfish', 'bossfight', 'caveflyer', 'chaser', - 'climber', 'coinrun', 'dodgeball', 'fruitbot', - 'heist', 'jumper', 'leaper', 'maze', - 'miner', 'ninja', 'plunder', 'starpilot' - ] - self.sup_algo = ['ppo'] - + self.sup_env = [ + "bigfish", + "bossfight", + "caveflyer", + "chaser", + "climber", + "coinrun", + "dodgeball", + "fruitbot", + "heist", + "jumper", + "leaper", + "maze", + "miner", + "ninja", + "plunder", + "starpilot", + ] + self.sup_algo = ["ppo"] def load_scores(self, env_id: str, agent: str) -> np.ndarray: """Returns final performance. - + Args: env_id (str): Environment ID. agent_id (str): Agent name. - + Returns: Test scores data array with shape (N_SEEDS, N_POINTS). """ self.is_available(env_id=env_id, agent=agent.lower()) - scores_file = f'{agent.lower()}_procgen_{env_id}_scores.npy' + scores_file = f"{agent.lower()}_procgen_{env_id}_scores.npy" file = hf_hub_download( - repo_id="RLE-Foundation/rllte-hub", - repo_type="model", - filename=scores_file, - subfolder="procgen/scores" + repo_id="RLE-Foundation/rllte-hub", repo_type="model", filename=scores_file, subfolder="procgen/scores" ) return np.load(file) @@ -85,7 +95,7 @@ def load_curves(self, env_id: str, agent: str) -> Dict[str, np.ndarray]: Args: env_id (str): Environment ID. agent_id (str): Agent name. - + Returns: Learning curves data with structure: curves @@ -94,13 +104,10 @@ def load_curves(self, env_id: str, agent: str) -> Dict[str, np.ndarray]: """ self.is_available(env_id=env_id, agent=agent.lower()) - curves_file = f'{agent.lower()}_procgen_{env_id}_curves.npz' + curves_file = f"{agent.lower()}_procgen_{env_id}_curves.npz" file = hf_hub_download( - repo_id="RLE-Foundation/rllte-hub", - repo_type="model", - filename=curves_file, - subfolder="procgen/curves" + repo_id="RLE-Foundation/rllte-hub", repo_type="model", filename=curves_file, subfolder="procgen/curves" ) curves_dict = np.load(file, allow_pickle=True) @@ -108,12 +115,7 @@ def load_curves(self, env_id: str, agent: str) -> Dict[str, np.ndarray]: return curves_dict - def load_models(self, - env_id: str, - agent: str, - seed: int, - device: str = "cpu" - ) -> nn.Module: + def load_models(self, env_id: str, agent: str, seed: int, device: str = "cpu") -> nn.Module: """Load the model from the hub. Args: @@ -126,7 +128,7 @@ def load_models(self, The loaded model. 
""" self.is_available(env_id=env_id, agent=agent.lower()) - + model_file = f"{agent.lower()}_procgen_{env_id}_seed_{seed}.pth" subfolder = f"procgen/{agent}" file = hf_hub_download(repo_id="RLE-Foundation/rllte-hub", repo_type="model", filename=model_file, subfolder=subfolder) @@ -134,13 +136,7 @@ def load_models(self, return model.eval() - - def load_apis(self, - env_id: str, - agent: str, - seed: int, - device: str = "cpu" - ) -> BaseAgent: + def load_apis(self, env_id: str, agent: str, seed: int, device: str = "cpu") -> BaseAgent: """Load the a training API. Args: @@ -153,27 +149,27 @@ def load_apis(self, The loaded API. """ envs = make_envpool_procgen_env( - env_id=env_id, - num_envs=64, - device=device, - seed=seed, - gamma=0.99, - num_levels=200, - start_level=0, - distribution_mode="easy", - asynchronous=False - ) + env_id=env_id, + num_envs=64, + device=device, + seed=seed, + gamma=0.99, + num_levels=200, + start_level=0, + distribution_mode="easy", + asynchronous=False, + ) eval_envs = make_envpool_procgen_env( - env_id=env_id, - num_envs=1, - device=device, - seed=seed, - gamma=0.99, - num_levels=0, - start_level=0, - distribution_mode="easy", - asynchronous=False, - ) + env_id=env_id, + num_envs=1, + device=device, + seed=seed, + gamma=0.99, + num_levels=0, + start_level=0, + distribution_mode="easy", + asynchronous=False, + ) feature_dim = 256 if agent.lower() == "ppo": @@ -201,28 +197,28 @@ def load_apis(self, elif agent.lower() == "daac": # Best hyperparameters for DAAC reported in # https://github.com/rraileanu/idaac/blob/main/hyperparams.py - if env_id in ['plunder', 'chaser']: + if env_id in ["plunder", "chaser"]: value_epochs = 1 else: value_epochs = 9 - - if env_id in ['miner', 'bigfish', 'dodgeball']: + + if env_id in ["miner", "bigfish", "dodgeball"]: value_freq = 32 - elif env_id == 'plunder': + elif env_id == "plunder": value_freq = 8 else: value_freq = 1 - - if env_id == 'plunder': + + if env_id == "plunder": adv_coef = 0.3 - elif env_id == 'chaser': + elif env_id == "chaser": adv_coef = 0.15 - elif env_id in ['climber', 'bigfish']: + elif env_id in ["climber", "bigfish"]: adv_coef = 0.05 else: adv_coef = 0.25 - api = DAAC( # type: ignore[assignment] + api = DAAC( # type: ignore[assignment] env=envs, eval_env=eval_envs, tag=f"daac_procgen_{env_id}_seed_{seed}", @@ -251,4 +247,4 @@ def load_apis(self, encoder = EspeholtResidualEncoder(observation_space=envs.observation_space, feature_dim=feature_dim) api.set(encoder=encoder) - return api \ No newline at end of file + return api diff --git a/rllte/xploit/policy/__init__.py b/rllte/xploit/policy/__init__.py index d8939d92..2c3d24e9 100644 --- a/rllte/xploit/policy/__init__.py +++ b/rllte/xploit/policy/__init__.py @@ -30,4 +30,4 @@ from .off_policy_stoch_actor_double_critic import OffPolicyStochActorDoubleCritic as OffPolicyStochActorDoubleCritic from .on_policy_decoupled_actor_critic import OnPolicyDecoupledActorCritic as OnPolicyDecoupledActorCritic from .on_policy_shared_actor_critic import OnPolicySharedActorCritic as OnPolicySharedActorCritic -from .on_policy_shared_actor_critic_lstm import OnPolicySharedActorCriticLSTM as OnPolicySharedActorCriticLSTM \ No newline at end of file +from .on_policy_shared_actor_critic_lstm import OnPolicySharedActorCriticLSTM as OnPolicySharedActorCriticLSTM diff --git a/rllte/xploit/policy/off_policy_stoch_actor_double_critic.py b/rllte/xploit/policy/off_policy_stoch_actor_double_critic.py index 5b63f750..07a10510 100644 --- 
a/rllte/xploit/policy/off_policy_stoch_actor_double_critic.py +++ b/rllte/xploit/policy/off_policy_stoch_actor_double_critic.py @@ -83,7 +83,7 @@ def __init__( # build actor and critic actor_kwargs = {"action_dim": self.policy_action_dim, "hidden_dim": self.hidden_dim, "feature_dim": self.feature_dim} if self.action_type == "Box": - actor_kwargs["log_std_range"] = log_std_range # type: ignore[assignment] + actor_kwargs["log_std_range"] = log_std_range # type: ignore[assignment] self.actor = get_off_policy_actor(action_type=self.action_type, actor_kwargs=actor_kwargs) diff --git a/rllte/xploit/policy/on_policy_shared_actor_critic_lstm.py b/rllte/xploit/policy/on_policy_shared_actor_critic_lstm.py index 70d48c6f..91ead5ee 100644 --- a/rllte/xploit/policy/on_policy_shared_actor_critic_lstm.py +++ b/rllte/xploit/policy/on_policy_shared_actor_critic_lstm.py @@ -92,7 +92,7 @@ def __init__( nn.init.constant_(param, 0) elif "weight" in name: nn.init.orthogonal_(param, 1.0) - + # build actor and critic actor_kwargs = dict( obs_shape=self.obs_shape, @@ -171,7 +171,9 @@ def get_states(self, obs: th.Tensor, lstm_state: th.Tensor, done: th.Tensor): new_hidden = th.flatten(th.cat(new_hidden), 0, 1) return new_hidden, lstm_state - def forward(self, obs: th.Tensor, lstm_state: th.Tensor, done: th.Tensor, training: bool = True) -> Tuple[th.Tensor, Dict[str, th.Tensor]]: + def forward( + self, obs: th.Tensor, lstm_state: th.Tensor, done: th.Tensor, training: bool = True + ) -> Tuple[th.Tensor, Dict[str, th.Tensor]]: """Get actions and estimated values for observations. Args: @@ -206,7 +208,9 @@ def get_value(self, obs: th.Tensor, lstm_state: th.Tensor, done: th.Tensor) -> t """ return self.critic(self.get_states(obs, lstm_state, done)[0]) - def evaluate_actions(self, obs: th.Tensor, actions: th.Tensor, lstm_state: th.Tensor, done: th.Tensor) -> Tuple[th.Tensor, th.Tensor, th.Tensor]: + def evaluate_actions( + self, obs: th.Tensor, actions: th.Tensor, lstm_state: th.Tensor, done: th.Tensor + ) -> Tuple[th.Tensor, th.Tensor, th.Tensor]: """Evaluate actions according to the current policy given the observations. Args: @@ -238,7 +242,9 @@ def get_policy_outputs(self, obs: th.Tensor, lstm_state: th.Tensor, done: th.Ten policy_outputs = self.actor.get_policy_outputs(h) return th.cat(policy_outputs, dim=1) - def get_dist_and_aux_value(self, obs: th.Tensor, lstm_state: th.Tensor, done: th.Tensor) -> Tuple[Distribution, th.Tensor, th.Tensor]: + def get_dist_and_aux_value( + self, obs: th.Tensor, lstm_state: th.Tensor, done: th.Tensor + ) -> Tuple[Distribution, th.Tensor, th.Tensor]: """Get probs and auxiliary estimated values for auxiliary phase update. 
diff --git a/rllte/xploit/policy/utils.py b/rllte/xploit/policy/utils.py
index dd00fd5c..9a02a980 100644
--- a/rllte/xploit/policy/utils.py
+++ b/rllte/xploit/policy/utils.py
@@ -548,7 +548,7 @@ def get_off_policy_actor(action_type: str, actor_kwargs: Dict) -> Union[OffPolic
     if action_type in ["Discrete"]:
         actor_class = OffPolicyDiscreteActor
     elif action_type == "Box":
-        actor_class = OffPolicyBoxActor # type: ignore[assignment]
+        actor_class = OffPolicyBoxActor  # type: ignore[assignment]
     else:
         raise NotImplementedError(f"Unsupported action type {action_type}!")
     return actor_class(**actor_kwargs)
diff --git a/rllte/xploit/storage/episodic_rollout_storage.py b/rllte/xploit/storage/episodic_rollout_storage.py
index 18352589..04261109 100644
--- a/rllte/xploit/storage/episodic_rollout_storage.py
+++ b/rllte/xploit/storage/episodic_rollout_storage.py
@@ -35,7 +35,7 @@
 
 class EpisodicRolloutStorage(BaseStorage):
-    """Episodic rollout storage for on-policy algorithms that use an LSTM. 
+    """Episodic rollout storage for on-policy algorithms that use an LSTM.
     It is the same as VanillaRolloutStorage but samples entire trajectories instead of batches of different steps.
 
     Args:
@@ -161,8 +161,8 @@ def compute_returns_and_advantages(self, last_values: th.Tensor) -> None:
         self.advantages = (self.advantages - self.advantages.mean()) / (self.advantages.std() + 1e-5)
 
     def sample(self) -> Generator:
-        """ 
-        Choose a minibatch of environment indices and sample the entire rollout for those minibatches. 
+        """
+        Choose a minibatch of environment indices and sample the entire rollout for those minibatches.
         By not sampling uniform transitions, we can now train an LSTM model on entire trajectories.
         """
         assert self.full, "Cannot sample when the storage is not full!"
@@ -193,4 +193,4 @@ def sample(self) -> Generator:
                 old_log_probs=b_log[mb_inds],
                 adv_targ=b_adv[mb_inds],
                 env_inds=indices,
-            )
\ No newline at end of file
+            )
diff --git a/rllte/xplore/reward/ride.py b/rllte/xplore/reward/ride.py
index 837f6211..089989f6 100644
--- a/rllte/xplore/reward/ride.py
+++ b/rllte/xplore/reward/ride.py
@@ -330,6 +330,6 @@ def update(self, samples: Dict) -> None:
         self.encoder_opt.step()
         self.im_opt.step()
         self.fm_opt.step()
-
+
     def add(self, samples: Dict) -> None:
         """Add new samples to the intrinsic reward module."""
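The trajectory-wise sampling that `EpisodicRolloutStorage.sample` describes above, reduced to a standalone sketch (the minibatch count and variable names are illustrative, not the storage's exact internals):

import numpy as np

num_envs, num_minibatches = 16, 4
indices = np.random.permutation(num_envs)
for env_inds in np.split(indices, num_minibatches):
    # Buffers shaped (num_steps, num_envs, ...) are sliced along the env axis,
    # so each minibatch holds complete, time-ordered trajectories an LSTM can
    # be unrolled over, rather than independently shuffled transitions.
    # yield {k: v[:, env_inds] for k, v in rollout_tensors.items()}
    pass

Sampling whole environments is what lets the LSTM see unbroken sequences; uniform per-step sampling, as in VanillaRolloutStorage, would destroy the temporal structure the recurrent state depends on.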
diff --git a/tests/test_env.py b/tests/test_env.py
index a6a0b907..b0057ac2 100644
--- a/tests/test_env.py
+++ b/tests/test_env.py
@@ -13,14 +13,7 @@
 
 @pytest.mark.parametrize(
     "env_cls",
-    [
-        make_atari_env,
-        make_minigrid_env,
-        make_procgen_env,
-        make_dmc_env,
-        make_envpool_atari_env,
-        make_envpool_procgen_env
-    ],
+    [make_atari_env, make_minigrid_env, make_procgen_env, make_dmc_env, make_envpool_atari_env, make_envpool_procgen_env],
 )
 @pytest.mark.parametrize("device", ["cuda", "cpu"])
 def test_discrete_env(env_cls, device):
@@ -28,7 +21,7 @@ def test_discrete_env(env_cls, device):
     if env_cls in [make_procgen_env]:
         env = env_cls(device=device, num_envs=num_envs)
     else:
-        # when set `asynchronous=True` for all the envs, 
+        # when set `asynchronous=True` for all the envs,
         # the test will raise an EOF error
         env = env_cls(device=device, num_envs=num_envs, asynchronous=False)
     _ = env.reset()
@@ -38,8 +31,7 @@ def test_discrete_env(env_cls, device):
     for _ in range(10):
         action = env.action_space.sample()
 
-        if env_cls in [make_atari_env, make_minigrid_env, make_procgen_env,
-                       make_envpool_atari_env, make_envpool_procgen_env]:
+        if env_cls in [make_atari_env, make_minigrid_env, make_procgen_env, make_envpool_atari_env, make_envpool_procgen_env]:
             action = th.randint(0, env.action_space.n, (num_envs,)).to(device)
         else:
             action = th.rand(size=(num_envs, env.action_space.shape[0])).to(device)
diff --git a/tests/test_reward.py b/tests/test_reward.py
index bb828008..2082fff2 100644
--- a/tests/test_reward.py
+++ b/tests/test_reward.py
@@ -32,13 +32,7 @@ def test_reward(reward, env_cls, device):
     action = th.randint(0, 2, (num_steps, num_envs, env.action_space.n)).float().to(device)
 
     for i in range(num_steps):
-        irs.add(
-            samples={
-                'obs': obs[i],
-                'actions': action[i],
-                'next_obs': obs[i]
-            }
-        )
+        irs.add(samples={"obs": obs[i], "actions": action[i], "next_obs": obs[i]})
 
     samples = {
         "obs": obs,
diff --git a/tests/test_storage.py b/tests/test_storage.py
index d9ee6edf..476589aa 100644
--- a/tests/test_storage.py
+++ b/tests/test_storage.py
@@ -22,7 +22,7 @@
         VanillaRolloutStorage,
         DictReplayStorage,
         DictRolloutStorage,
-        HerReplayStorage
+        HerReplayStorage,
     ],
 )
 @pytest.mark.parametrize("device", ["cpu", "cuda"])
@@ -45,7 +45,7 @@ def test_storage(storage_cls, device):
             device=device,
             num_envs=num_envs,
             storage_size=num_steps,
-            batch_size=batch_size
+            batch_size=batch_size,
         )
     elif storage_cls is HerReplayStorage:
         storage = storage_cls(
@@ -54,7 +54,7 @@ def test_storage(storage_cls, device):
             device=device,
             num_envs=num_envs,
             batch_size=batch_size,
-            reward_fn=lambda x, y, z: th.rand(size=(int(batch_size * 0.8), 1))
+            reward_fn=lambda x, y, z: th.rand(size=(int(batch_size * 0.8), 1)),
         )
     else:
         storage = storage_cls(
@@ -62,7 +62,7 @@ def test_storage(storage_cls, device):
             action_space=env.action_space,
             device=device,
             num_envs=num_envs,
-            batch_size=batch_size
+            batch_size=batch_size,
         )
 
     obs, infos = env.reset()
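If the new storage were added to the parametrized list in tests/test_storage.py, its construction would plausibly mirror the branches above. The exact constructor signature of EpisodicRolloutStorage is not shown in this section, so the kwargs below are assumptions patterned on the sibling storages:

    elif storage_cls is EpisodicRolloutStorage:
        storage = storage_cls(
            observation_space=env.observation_space,
            action_space=env.action_space,
            device=device,
            num_envs=num_envs,
            batch_size=batch_size,  # assumed; the storage may instead batch by env indices
        )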