diff --git a/sb3_contrib/__init__.py b/sb3_contrib/__init__.py
index 2aa7a19b..569ba731 100644
--- a/sb3_contrib/__init__.py
+++ b/sb3_contrib/__init__.py
@@ -4,6 +4,7 @@
 from sb3_contrib.crossq import CrossQ
 from sb3_contrib.ppo_mask import MaskablePPO
 from sb3_contrib.ppo_recurrent import RecurrentPPO
+from sb3_contrib.ppo_hybrid import HybridPPO
 from sb3_contrib.qrdqn import QRDQN
 from sb3_contrib.tqc import TQC
 from sb3_contrib.trpo import TRPO
@@ -21,4 +22,5 @@
     "CrossQ",
     "MaskablePPO",
     "RecurrentPPO",
+    "HybridPPO",
 ]
diff --git a/sb3_contrib/common/envs/__init__.py b/sb3_contrib/common/envs/__init__.py
index e9f740b9..68bb6be8 100644
--- a/sb3_contrib/common/envs/__init__.py
+++ b/sb3_contrib/common/envs/__init__.py
@@ -3,5 +3,8 @@
     InvalidActionEnvMultiBinary,
     InvalidActionEnvMultiDiscrete,
 )
+from sb3_contrib.common.envs.hybrid_actions_env import (
+    CatchingPointEnv,
+)
 
-__all__ = ["InvalidActionEnvDiscrete", "InvalidActionEnvMultiBinary", "InvalidActionEnvMultiDiscrete"]
+__all__ = ["InvalidActionEnvDiscrete", "InvalidActionEnvMultiBinary", "InvalidActionEnvMultiDiscrete", "CatchingPointEnv"]
diff --git a/sb3_contrib/common/envs/hybrid_actions_env.py b/sb3_contrib/common/envs/hybrid_actions_env.py
new file mode 100644
index 00000000..fe0da36c
--- /dev/null
+++ b/sb3_contrib/common/envs/hybrid_actions_env.py
@@ -0,0 +1,125 @@
+import gymnasium as gym
+from gymnasium import spaces
+import numpy as np
+
+
+class CatchingPointEnv(gym.Env):
+    """
+    Environment for Hybrid PPO implementing the 'Catching Point' task of the paper
+    'Hybrid Actor-Critic Reinforcement Learning in Parameterized Action Space', Fan et al.
+    (https://arxiv.org/pdf/1903.01344)
+    """
+
+    def __init__(
+        self,
+        arena_size: float = 1.0,
+        move_dist: float = 0.05,
+        catch_radius: float = 0.05,
+        max_catches: int = 10,
+        max_steps: int = 100,
+    ):
+        super().__init__()
+        self.max_steps = max_steps
+        self.max_catches = max_catches
+        self.arena_size = arena_size
+        self.move_dist = move_dist
+        self.catch_radius = catch_radius
+
+        # action space
+        self.action_space = spaces.Tuple(
+            spaces=(
+                spaces.MultiDiscrete([2]),  # MOVE=0, CATCH=1
+                spaces.Box(low=-1.0, high=1.0, shape=(2,), dtype=np.float32),  # direction
+            )
+        )
+
+        # observation: [agent_x, agent_y, target_x, target_y, catches_left_norm, step_norm]
+        obs_low = np.array([-arena_size, -arena_size, -arena_size, -arena_size, 0.0, 0.0], dtype=np.float32)
+        obs_high = np.array([arena_size, arena_size, arena_size, arena_size, 1.0, 1.0], dtype=np.float32)
+        self.observation_space = spaces.Box(obs_low, obs_high, dtype=np.float32)
+
+    def reset(self, seed=None, options=None) -> tuple[np.ndarray, dict]:
+        """
+        Reset the environment to an initial state and return the initial observation.
+        """
+        super().reset(seed=seed)
+        self.agent_pos = self.np_random.uniform(-self.arena_size, self.arena_size, size=2).astype(np.float32)
+        self.target_pos = self.np_random.uniform(-self.arena_size, self.arena_size, size=2).astype(np.float32)
+        self.catches_used = 0
+        self.step_count = 0
+        return self._get_obs(), {}
+
+    def step(self, concat_actions: np.ndarray) -> tuple[np.ndarray, float, bool, bool, dict]:
+        """
+        Step the environment with the given actions.
+        Compatible with the VecEnv interface.
+
+        :param concat_actions: A concatenated array containing both discrete and continuous actions.
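+            Expected layout: the first entries (one per dimension of the MultiDiscrete space)
+            are the discrete choices, followed by the flattened continuous action.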
+        :return: observation, reward, terminated, truncated, info
+        """
+        # Unstack the concatenated actions back into discrete and continuous
+        n_discrete_actions = self.action_space[0].nvec.size if isinstance(self.action_space[0], spaces.MultiDiscrete) else 1
+        actions_d = concat_actions[:n_discrete_actions].astype(int)
+        actions_c = concat_actions[n_discrete_actions:]
+
+        # actual step logic
+        return self._step(actions_d, actions_c)
+
+    def _step(self, actions_d: np.ndarray, actions_c: np.ndarray) -> tuple[np.ndarray, float, bool, bool, dict]:
+        """
+        Take a step in the environment using the provided actions.
+
+        :param actions_d: Discrete action
+        :param actions_c: Continuous action
+        :return: observation, reward, terminated, truncated, info
+        """
+        action_d = actions_d.item()  # only 1 discrete action -> extract scalar
+        dir_vec = actions_c
+        terminated = False
+        truncated = False
+
+        # step penalty
+        reward = -0.01
+
+        # MOVE
+        if action_d == 0:
+            norm = np.linalg.norm(dir_vec)
+            if norm > 1e-8:  # guard against a zero-length direction to avoid division by zero
+                dir_u = dir_vec / norm
+                self.agent_pos = (self.agent_pos + dir_u * self.move_dist).astype(np.float32)
+                # clamp to arena
+                self.agent_pos = np.clip(self.agent_pos, -self.arena_size, self.arena_size)
+
+        # CATCH
+        else:
+            reward -= 0.05  # catch attempt penalty
+            self.catches_used += 1
+            dist = np.linalg.norm(self.agent_pos - self.target_pos)
+            if dist <= self.catch_radius:
+                reward = 1.0  # caught the target
+                terminated = True  # target caught: natural termination
+            elif self.catches_used >= self.max_catches:
+                reward = -1.0  # failed to catch within max catches
+                terminated = True  # max catches reached: natural termination
+
+        self.step_count += 1
+        if self.step_count >= self.max_steps:
+            reward = -1.0  # failed to catch within max steps
+            truncated = True  # max steps reached: truncation
+
+        obs = self._get_obs()
+        info = {"caught": (reward > 0)}
+        return obs, float(reward), terminated, truncated, info
+
+    def _get_obs(self) -> np.ndarray:
+        """
+        Get the current observation.
+        """
+        step_norm = self.step_count / self.max_steps
+        catches_left_norm = (self.max_catches - self.catches_used) / self.max_catches
+        obs = np.concatenate((
+            self.agent_pos,
+            self.target_pos,
+            np.array([catches_left_norm, step_norm]),
+        ), dtype=np.float32)
+        return obs
diff --git a/sb3_contrib/common/hybrid/__init__.py b/sb3_contrib/common/hybrid/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/sb3_contrib/common/hybrid/distributions.py b/sb3_contrib/common/hybrid/distributions.py
new file mode 100644
index 00000000..fc8f81e8
--- /dev/null
+++ b/sb3_contrib/common/hybrid/distributions.py
@@ -0,0 +1,204 @@
+import numpy as np
+import torch as th
+from torch import nn
+from typing import Any, Optional, TypeVar, Union
+from stable_baselines3.common.distributions import Distribution
+from torch.distributions import Categorical, Normal
+from gymnasium import spaces
+
+
+SelfHybridDistribution = TypeVar("SelfHybridDistribution", bound="HybridDistribution")
+
+
+class HybridDistributionNet(nn.Module):
+    """
+    Network that outputs the parameters of a hybrid distribution:
+    logits for each categorical (discrete) component and means for the Gaussian (continuous) component.
+    """
+
+    def __init__(self, latent_dim: int, categorical_dimensions: np.ndarray, n_continuous: int):
+        """
+        Constructor.
+ + :param latent_dim: Dimension of the latent vector from the policy network + :param categorical_dimensions: Array specifying the number of discrete actions for each categorical distribution + :param n_continuous: Number of continuous actions + """ + super().__init__() + # For discrete action space + self.categorical_nets = nn.ModuleList([nn.Linear(latent_dim, out_dim) for out_dim in categorical_dimensions]) + # For continuous action space + self.gaussian_net = nn.Linear(latent_dim, n_continuous) + + def forward(self, latent: th.Tensor) -> tuple[list[th.Tensor], th.Tensor]: + """ + Forward pass through all categorical nets and the gaussian net. + + :param latent: Latent tensor input + :return: Tuple (list of categorical outputs, gaussian output) + """ + categorical_outputs = [net(latent) for net in self.categorical_nets] + gaussian_output = self.gaussian_net(latent) + return categorical_outputs, gaussian_output + + +class Hybrid(th.distributions.Distribution): + """ + A hybrid distribution that combines multiple categorical distributions for discrete actions + and a Gaussian distribution for continuous actions. + """ + + def __init__(self, + probs: Optional[tuple[list[th.Tensor], th.Tensor]] = None, + logits: Optional[tuple[list[th.Tensor], th.Tensor]] = None, + validate_args: Optional[bool] = None, + ): + super().__init__() + categorical_logits: list[th.Tensor] = logits[0] + gaussian_means: th.Tensor = logits[1] + self.categorical_dists = [Categorical(logits=logit) for logit in categorical_logits] + self.gaussian_dist = Normal(loc=gaussian_means, scale=th.ones_like(gaussian_means)) + + def sample(self) -> tuple[th.Tensor, th.Tensor]: + categorical_samples = [dist.sample() for dist in self.categorical_dists] + gaussian_samples = self.gaussian_dist.sample() + return th.stack(categorical_samples, dim=-1), gaussian_samples + + def log_prob(self, discrete_actions: th.Tensor, continuous_actions: th.Tensor) -> tuple[th.Tensor, th.Tensor]: + """ + Returns the log probability of the given actions, both discrete and continuous. + """ + discrete_action_list = discrete_actions.unbind(dim=-1) + categorical_log_probs = [dist.log_prob(action_d) for dist, action_d in zip(self.categorical_dists, discrete_action_list)] + gaussian_log_prob = self.gaussian_dist.log_prob(continuous_actions) + # TODO: check dimensions + return ( + th.sum(th.stack(categorical_log_probs, dim=-1), dim=-1), + th.sum(gaussian_log_prob, dim=-1) + ) + + def entropy(self) -> tuple[th.Tensor, th.Tensor]: + """ + Returns the entropy of the hybrid distribution, which is the sum of the entropies + of the categorical and gaussian components. + + :return: Tuple of (categorical entropy, gaussian entropy) + """ + categorical_entropies = [dist.entropy() for dist in self.categorical_dists] + # Sum entropies for all categorical distributions + categorical_entropy = th.sum(th.stack(categorical_entropies, dim=-1), dim=-1) + gaussian_entropy = self.gaussian_dist.entropy().sum(dim=-1) + return categorical_entropy, gaussian_entropy + + +class HybridDistribution(Distribution): + def __init__(self, categorical_dimensions: np.ndarray, n_continuous: int): + """ + Initialize the hybrid distribution with categorical and continuous components. + + :param categorical_dimensions: An array specifying the dimensions of the categorical actions. + :param n_continuous: The number of continuous actions. 
+ """ + super().__init__() + self.categorical_dimensions = categorical_dimensions + self.n_continuous = n_continuous + self.categorical_dists = None + self.gaussian_dist = None + + def proba_distribution_net(self, latent_dim: int) -> Union[nn.Module, tuple[nn.Module, nn.Parameter]]: + """Create the layers and parameters that represent the distribution. + + Subclasses must define this, but the arguments and return type vary between + concrete classes.""" + action_net = HybridDistributionNet(latent_dim, self.categorical_dimensions, self.n_continuous) + return action_net + + def proba_distribution(self: SelfHybridDistribution, action_logits: tuple[list[th.Tensor], th.Tensor]) -> SelfHybridDistribution: + """Set parameters of the distribution. + + :return: self + """ + self.distribution = Hybrid(logits=action_logits) + return self + + def log_prob(self, discrete_actions: th.Tensor, continuous_actions: th.Tensor) -> tuple[th.Tensor, th.Tensor]: + """ + Returns the log likelihood + + :param x: the taken action + :return: The log likelihood of the distribution for discrete and continuous distributions + """ + assert self.distribution is not None, "Must set distribution parameters" + return self.distribution.log_prob(discrete_actions, continuous_actions) + + def entropy(self) -> tuple[th.Tensor, th.Tensor]: + """ + Returns Shannon's entropy of the probability + + :return: the entropy of discrete and continuous distributions + """ + assert self.distribution is not None, "Must set distribution parameters" + return self.distribution.entropy() + + def sample(self) -> tuple[th.Tensor, th.Tensor]: + """ + Returns a sample from the probability distribution + + :return: the stochastic action + """ + assert self.distribution is not None, "Must set distribution parameters" + return self.distribution.sample() + + def mode(self) -> th.Tensor: + """ + Returns the most likely action (deterministic output) + from the probability distribution + + :return: the stochastic action + """ + + def get_actions(self, deterministic: bool = False) -> tuple[th.Tensor, th.Tensor]: + """ + Return actions according to the probability distribution. + + :param deterministic: + :return: + """ + if deterministic: + return self.mode() + return self.sample() + + def actions_from_params(self, *args, **kwargs) -> th.Tensor: + """ + Returns samples from the probability distribution + given its parameters. + + :return: actions + """ + + def log_prob_from_params(self, *args, **kwargs) -> tuple[th.Tensor, th.Tensor]: + """ + Returns samples and the associated log probabilities + from the probability distribution given its parameters. + + :return: actions and log prob + """ + + +def make_hybrid_proba_distribution(action_space: spaces.Tuple) -> HybridDistribution: + """ + Create a hybrid probability distribution for the given action space. + + :param action_space: Tuple Action space containing a MultiDiscrete action space and a Box action space. + :return: A HybridDistribution object that handles the hybrid action space. 
+ """ + assert isinstance(action_space, spaces.Tuple), "Action space must be a Tuple space" + assert len(action_space.spaces) == 2, "Action space must contain exactly 2 subspaces" + assert isinstance(action_space.spaces[0], spaces.MultiDiscrete), "First subspace must be MultiDiscrete" + assert isinstance(action_space.spaces[1], spaces.Box), "Second subspace must be Box" + assert len(action_space[1].shape) == 1, "Continuous action space must have a monodimensional shape (e.g., (n,))" + return HybridDistribution( + categorical_dimensions=action_space[0].nvec, + n_continuous=action_space[1].shape[0] + ) + diff --git a/sb3_contrib/common/hybrid/policies.py b/sb3_contrib/common/hybrid/policies.py new file mode 100644 index 00000000..00dc7aa0 --- /dev/null +++ b/sb3_contrib/common/hybrid/policies.py @@ -0,0 +1,305 @@ +from functools import partial +from typing import Any, Optional, Union +import warnings +import numpy as np +from stable_baselines3.common.policies import BasePolicy +from gymnasium import spaces +from stable_baselines3.common.type_aliases import PyTorchObs, Schedule +import torch as th +from torch import nn +from stable_baselines3.common.torch_layers import ( + BaseFeaturesExtractor, + CombinedExtractor, + FlattenExtractor, + MlpExtractor, + NatureCNN, +) + +from sb3_contrib.common.hybrid.distributions import HybridDistribution, make_hybrid_proba_distribution + + +class HybridActorCriticPolicy(BasePolicy): + """ + Policy class for actor-critic algorithms (has both policy and value prediction). + Used by A2C, PPO and the likes. + + :param observation_space: Observation space + :param action_space: Tuple Action space containing a MultiDiscrete action space and a Box action space. + :param lr_schedule: Learning rate schedule (could be constant) + :param net_arch: The specification of the policy and value networks. + :param activation_fn: Activation function + :param ortho_init: Whether to use or not orthogonal initialization + :param log_std_init: Initial value for the log standard deviation + :param features_extractor_class: Features extractor to use. + :param features_extractor_kwargs: Keyword arguments + to pass to the features extractor. + :param share_features_extractor: If True, the features extractor is shared between the policy and value networks. 
+ :param normalize_images: Whether to normalize images or not, + dividing by 255.0 (True by default) + :param optimizer_class: The optimizer to use, + ``th.optim.Adam`` by default + :param optimizer_kwargs: Additional keyword arguments, + excluding the learning rate, to pass to the optimizer + """ + + def __init__( + self, + observation_space: spaces.Space, + action_space: spaces.Tuple, # Type[spaces.MultiDiscrete, spaces.Box] + lr_schedule: Schedule, + net_arch: Optional[Union[list[int], dict[str, list[int]]]] = None, + activation_fn: type[nn.Module] = nn.Tanh, + ortho_init: bool = True, + log_std_init: float = 0.0, + features_extractor_class: type[BaseFeaturesExtractor] = FlattenExtractor, + features_extractor_kwargs: Optional[dict[str, Any]] = None, + share_features_extractor: bool = True, + normalize_images: bool = True, + optimizer_class: type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[dict[str, Any]] = None, + ): + if optimizer_kwargs is None: + optimizer_kwargs = {} + # Small values to avoid NaN in Adam optimizer + if optimizer_class == th.optim.Adam: + optimizer_kwargs["eps"] = 1e-5 + + super().__init__( + observation_space, + action_space, + features_extractor_class, + features_extractor_kwargs, + optimizer_class=optimizer_class, + optimizer_kwargs=optimizer_kwargs, + normalize_images=normalize_images, + squash_output=False, + ) + + # assert that the action space is compatible with its type hint + assert isinstance(action_space, spaces.Tuple), "Action space must be a gymnasium.spaces.Tuple" + assert len(action_space.spaces) == 2, "Action space Tuple must contain exactly two spaces" + assert isinstance(action_space.spaces[0], spaces.MultiDiscrete), "First element of action space Tuple must be MultiDiscrete" + assert isinstance(action_space.spaces[1], spaces.Box), "Second element of action space Tuple must be Box" + + if isinstance(net_arch, list) and len(net_arch) > 0 and isinstance(net_arch[0], dict): + warnings.warn( + ( + "As shared layers in the mlp_extractor are removed since SB3 v1.8.0, " + "you should now pass directly a dictionary and not a list " + "(net_arch=dict(pi=..., vf=...) instead of net_arch=[dict(pi=..., vf=...)])" + ), + ) + net_arch = net_arch[0] + + # Default network architecture, from stable-baselines + if net_arch is None: + if features_extractor_class == NatureCNN: + net_arch = [] + else: + net_arch = dict(pi=[64, 64], vf=[64, 64]) + + self.net_arch = net_arch + self.activation_fn = activation_fn + self.ortho_init = ortho_init + + # features extractor + self.share_features_extractor = share_features_extractor + self.features_extractor = self.make_features_extractor() + self.features_dim = self.features_extractor.features_dim + if self.share_features_extractor: + self.pi_features_extractor = self.features_extractor + self.vf_features_extractor = self.features_extractor + else: + self.pi_features_extractor = self.features_extractor + self.vf_features_extractor = self.make_features_extractor() + + self.log_std_init = log_std_init + + # Action distribution + self.action_dist: HybridDistribution = make_hybrid_proba_distribution(action_space) + + self._build(lr_schedule) + + def _build_mlp_extractor(self) -> None: + """ + Create the policy and value networks. 
+ """ + self.mlp_extractor = MlpExtractor( + self.features_dim, + net_arch=self.net_arch, + activation_fn=self.activation_fn, + device=self.device, + ) + + def _build(self, lr_schedule: Schedule) -> None: + self._build_mlp_extractor() + + # Create action net and value net + self.action_net = self.action_dist.proba_distribution_net(latent_dim=self.mlp_extractor.latent_dim_pi) + self.value_net = nn.Linear(self.mlp_extractor.latent_dim_vf, 1) + + # Init weights: use orthogonal initialization + # with small initial weight for the output + if self.ortho_init: + # TODO: check for features_extractor + # Values from stable-baselines. + # features_extractor/mlp values are + # originally from openai/baselines (default gains/init_scales). + module_gains = { + self.features_extractor: np.sqrt(2), + self.mlp_extractor: np.sqrt(2), + self.action_net: 0.01, + self.value_net: 1, + } + if not self.share_features_extractor: + # Note(antonin): this is to keep SB3 results + # consistent, see GH#1148 + del module_gains[self.features_extractor] + module_gains[self.pi_features_extractor] = np.sqrt(2) + module_gains[self.vf_features_extractor] = np.sqrt(2) + + for module, gain in module_gains.items(): + module.apply(partial(self.init_weights, gain=gain)) + + # Setup optimizer with initial learning rate + self.optimizer = self.optimizer_class( + self.parameters(), + lr=lr_schedule(1), # type: ignore[call-arg] + **self.optimizer_kwargs, + ) + + def forward(self, obs: th.Tensor, deterministic: bool = False) -> tuple[th.Tensor, th.Tensor, th.Tensor]: + """ + Forward pass in all the networks (actor and critic) + + :param obs: Observation + :param deterministic: Whether to sample or use deterministic actions + :return: action, value and log probability of the action + """ + # Preprocess the observation if needed + features = self.extract_features(obs) + if self.share_features_extractor: + latent_pi, latent_vf = self.mlp_extractor(features) + else: + pi_features, vf_features = features + latent_pi = self.mlp_extractor.forward_actor(pi_features) + latent_vf = self.mlp_extractor.forward_critic(vf_features) + # Evaluate the values for the given observations + values = self.value_net(latent_vf) + distribution = self._get_action_dist_from_latent(latent_pi) + actions_d, actions_c = distribution.get_actions(deterministic=deterministic) + log_prob_d, log_prob_c = distribution.log_prob(actions_d, actions_c) + return actions_d, actions_c, values, log_prob_d, log_prob_c + + def extract_features( # type: ignore[override] + self, obs: PyTorchObs, features_extractor: Optional[BaseFeaturesExtractor] = None + ) -> Union[th.Tensor, tuple[th.Tensor, th.Tensor]]: + """ + Preprocess the observation if needed and extract features. + + :param obs: Observation + :param features_extractor: The features extractor to use. If None, then ``self.features_extractor`` is used. + :return: The extracted features. If features extractor is not shared, returns a tuple with the + features for the actor and the features for the critic. 
+ """ + if self.share_features_extractor: + return super().extract_features(obs, features_extractor or self.features_extractor) + else: + if features_extractor is not None: + warnings.warn( + "Provided features_extractor will be ignored because the features extractor is not shared.", + UserWarning, + ) + + pi_features = super().extract_features(obs, self.pi_features_extractor) + vf_features = super().extract_features(obs, self.vf_features_extractor) + return pi_features, vf_features + + def _get_action_dist_from_latent(self, latent_pi: th.Tensor) -> HybridDistribution: + """ + Retrieve action distribution given the latent codes. + + :param latent_pi: Latent code for the actor + :return: Action distribution + """ + action_logits: tuple[list[th.Tensor], th.Tensor] = self.action_net(latent_pi) + return self.action_dist.proba_distribution(action_logits=action_logits) + + def get_distribution(self, obs: PyTorchObs) -> HybridDistribution: + """ + Get the current policy distribution given an observation. + + :param obs: Observation + :return: The action distribution + """ + features = self.extract_features(obs) + if self.share_features_extractor: + latent_pi = self.mlp_extractor.forward_actor(features) + else: + pi_features, vf_features = features + latent_pi = self.mlp_extractor.forward_actor(pi_features) + distribution = self._get_action_dist_from_latent(latent_pi) + return distribution + + def _predict(self, observation: PyTorchObs, deterministic: bool = False) -> th.Tensor: + """ + Get the action according to the policy for a given observation. + + :param observation: Observation + :param deterministic: Whether to use stochastic or deterministic actions + :return: Taken action according to the policy + """ + return self.get_distribution(observation).get_actions(deterministic=deterministic) + + def evaluate_actions( + self, + obs: PyTorchObs, + actions_d: th.Tensor, + actions_c: th.Tensor + ) -> tuple[th.Tensor, th.Tensor, Optional[th.Tensor]]: + """ + Evaluate actions according to the current policy, + given the observations. + + :param obs: Observation + :param actions_d: Discrete actions + :param actions_c: Continuous actions + :return: estimated value, log likelihood of taking those actions + and entropy of the action distribution. + """ + # Preprocess the observation if needed + features = self.extract_features(obs) + if self.share_features_extractor: + latent_pi, latent_vf = self.mlp_extractor(features) + else: + pi_features, vf_features = features + latent_pi = self.mlp_extractor.forward_actor(pi_features) + latent_vf = self.mlp_extractor.forward_critic(vf_features) + distribution = self._get_action_dist_from_latent(latent_pi) + # log prob of discrete and continuous actions + log_prob: tuple[th.Tensor, th.Tensor] = distribution.log_prob(actions_d, actions_c) + # entropy of discrete and continuous actions + entropy: tuple[th.Tensor, th.Tensor] = distribution.entropy() + values = self.value_net(latent_vf) + return values, log_prob, entropy + + def predict_values(self, obs: PyTorchObs) -> th.Tensor: + """ + Get the estimated values according to the current policy given the observations. + + :param obs: Observation + :return: the estimated values. 
+ """ + features = super().extract_features(obs, self.vf_features_extractor) + latent_vf = self.mlp_extractor.forward_critic(features) + return self.value_net(latent_vf) + + +# TODO: check superclass +class HybridActorCriticCnnPolicy(HybridActorCriticPolicy): + pass + + +# TODO: check superclass +class HybridMultiInputActorCriticPolicy(HybridActorCriticPolicy): + pass \ No newline at end of file diff --git a/sb3_contrib/ppo_hybrid/__init__.py b/sb3_contrib/ppo_hybrid/__init__.py new file mode 100644 index 00000000..515da75a --- /dev/null +++ b/sb3_contrib/ppo_hybrid/__init__.py @@ -0,0 +1,5 @@ +from sb3_contrib.ppo_hybrid.policies import CnnPolicy, MlpPolicy, MultiInputPolicy +from sb3_contrib.ppo_hybrid.ppo_hybrid import HybridPPO +from sb3_contrib.ppo_hybrid.buffers import HybridActionsRolloutBuffer + +__all__ = ["CnnPolicy", "HybridPPO", "MlpPolicy", "MultiInputPolicy"] diff --git a/sb3_contrib/ppo_hybrid/buffers.py b/sb3_contrib/ppo_hybrid/buffers.py new file mode 100644 index 00000000..8d6539f0 --- /dev/null +++ b/sb3_contrib/ppo_hybrid/buffers.py @@ -0,0 +1,196 @@ +from collections.abc import Generator +import numpy as np +from stable_baselines3.common.buffers import RolloutBuffer +from gymnasium import spaces +from typing import Optional, Union, NamedTuple +import torch as th +from stable_baselines3.common.preprocessing import get_obs_shape +from stable_baselines3.common.utils import get_device +from stable_baselines3.common.vec_env import VecNormalize + + +class HybridActionsRolloutBufferSamples(NamedTuple): + observations: th.Tensor + actions_d: th.Tensor + actions_c: th.Tensor + old_values: th.Tensor + old_log_prob_d: th.Tensor + old_log_prob_c: th.Tensor + advantages: th.Tensor + returns: th.Tensor + + +def get_action_dim(action_space: spaces.Tuple) -> tuple[int, int]: + """ + Get the dimension of the action space, + assumed to be the one of HybridPPO (Tuple[MultiDiscrete, Box]). + + :param action_space: Tuple action space containing MultiDiscrete and Box spaces + :return: (dim_d, dim_c) where dim_d is the discrete action dimension and dim_c the continuous action dimension. + """ + assert isinstance(action_space, spaces.Tuple), "Action space must be a Tuple space" + assert len(action_space.spaces) == 2, "Action space must contain exactly 2 subspaces" + assert isinstance(action_space.spaces[0], spaces.MultiDiscrete), "First subspace must be MultiDiscrete" + assert isinstance(action_space.spaces[1], spaces.Box), "Second subspace must be Box" + return ( + len(action_space.spaces[0].nvec), # discrete action dimension + int(np.prod(action_space.spaces[1].shape)) # continuous action dimension + ) + + +class HybridActionsRolloutBuffer(RolloutBuffer): + """ + Rollout buffer for hybrid action spaces (discrete + continuous). + Stores separate actions and log probabilities for discrete and continuous parts. + """ + + actions_d: np.ndarray + actions_c: np.ndarray + log_probs_d: np.ndarray + log_probs_c: np.ndarray + + def __init__( + self, + buffer_size: int, + observation_space: spaces.Space, + action_space: spaces.Tuple, # Type[spaces.MultiDiscrete, spaces.Box] + device: Union[th.device, str] = "auto", + gae_lambda: float = 1, + gamma: float = 0.99, + n_envs: int = 1, + ): + # NOTE: it would be nice to use RolloutBuffer.__init__(), but BaseBuffer calls + # get_action_dim which is not compatible with Tuple action spaces. 
+ + # BaseBuffer's constructor code (excluding call to ABC.__init__()) + self.buffer_size = buffer_size + self.observation_space = observation_space + self.action_space = action_space + self.obs_shape = get_obs_shape(observation_space) # type: ignore[assignment] + + self.action_dim: tuple[int, int] = get_action_dim(action_space) + self.pos = 0 + self.full = False + self.device = get_device(device) + self.n_envs = n_envs + + # RolloutBuffer's constructor code + self.gae_lambda = gae_lambda + self.gamma = gamma + self.generator_ready = False + self.reset() + + def reset(self) -> None: + self.observations = np.zeros((self.buffer_size, self.n_envs, *self.obs_shape), dtype=self.observation_space.dtype) + self.actions_d = np.zeros((self.buffer_size, self.n_envs, self.action_dim[0]), dtype=np.float32) + self.actions_c = np.zeros((self.buffer_size, self.n_envs, self.action_dim[1]), dtype=np.float32) + self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) + self.returns = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) + self.episode_starts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) + self.values = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) + self.log_probs_d = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) + self.log_probs_c = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) + # single advantages buffer for discrete and continuous actions + # because it's estimated using only the values and rewards (not dependent on actions directly) + # (see Eq. 7 in Hybrid PPO paper) + self.advantages = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) + self.generator_ready = False + self.pos = 0 + self.full = False + + def add( + self, + obs: np.ndarray, + action_d: np.ndarray, + action_c: np.ndarray, + reward: np.ndarray, + episode_start: np.ndarray, + value: th.Tensor, + log_prob_d: th.Tensor, + log_prob_c: th.Tensor, + ) -> None: + """ + :param obs: Observation + :param action_d: Discrete action + :param action_c: Continuous action + :param reward: + :param episode_start: Start of episode signal. + :param value: estimated value of the current state + following the current policy. + :param log_prob_d: log probabilities of the discrete action following the current policy. + :param log_prob_c: log probabilities of the continuous action following the current policy. 
+ """ + # Reshape 0-d tensor to avoid error + if len(log_prob_d.shape) == 0: + log_prob_d = log_prob_d.reshape(-1, 1) + if len(log_prob_c.shape) == 0: + log_prob_c = log_prob_c.reshape(-1, 1) + + # copied from RolloutBuffer: + # Reshape needed when using multiple envs with discrete observations + # as numpy cannot broadcast (n_discrete,) to (n_discrete, 1) + if isinstance(self.observation_space, spaces.Discrete): + obs = obs.reshape((self.n_envs, *self.obs_shape)) + + # Adapted from RolloutBuffer: + # Reshape to handle multi-dim and discrete action spaces, see GH #970 #1392 + action_d = action_d.reshape((self.n_envs, self.action_dim[0])) + action_c = action_c.reshape((self.n_envs, self.action_dim[1])) + + self.observations[self.pos] = obs + self.actions_d[self.pos] = action_d + self.actions_c[self.pos] = action_c + self.rewards[self.pos] = reward + self.episode_starts[self.pos] = episode_start + self.values[self.pos] = value.clone().cpu().numpy().flatten() + self.log_probs_d[self.pos] = log_prob_d.clone().cpu().numpy() + self.log_probs_c[self.pos] = log_prob_c.clone().cpu().numpy() + self.pos += 1 + if self.pos == self.buffer_size: + self.full = True + + def get(self, batch_size: Optional[int] = None) -> Generator[HybridActionsRolloutBufferSamples, None, None]: + assert self.full, "" + indices = np.random.permutation(self.buffer_size * self.n_envs) + # Prepare the data + if not self.generator_ready: + _tensor_names = [ + "observations", + "actions_d", + "actions_c", + "values", + "log_probs_d", + "log_probs_c", + "advantages", + "returns", + ] + + for tensor in _tensor_names: + self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor]) + self.generator_ready = True + + # Return everything, don't create minibatches + if batch_size is None: + batch_size = self.buffer_size * self.n_envs + + start_idx = 0 + while start_idx < self.buffer_size * self.n_envs: + yield self._get_samples(indices[start_idx : start_idx + batch_size]) + start_idx += batch_size + + def _get_samples( + self, + batch_inds: np.ndarray, + env: Optional[VecNormalize] = None # TODO: check type hint + ) -> HybridActionsRolloutBufferSamples: + data = ( + self.observations[batch_inds], + self.actions_d[batch_inds], + self.actions_c[batch_inds], + self.values[batch_inds].flatten(), + self.log_probs_d[batch_inds].flatten(), + self.log_probs_c[batch_inds].flatten(), + self.advantages[batch_inds].flatten(), + self.returns[batch_inds].flatten(), + ) + return HybridActionsRolloutBufferSamples(*tuple(map(self.to_torch, data))) diff --git a/sb3_contrib/ppo_hybrid/policies.py b/sb3_contrib/ppo_hybrid/policies.py new file mode 100644 index 00000000..15f60e58 --- /dev/null +++ b/sb3_contrib/ppo_hybrid/policies.py @@ -0,0 +1,9 @@ +from sb3_contrib.common.hybrid.policies import ( + HybridActorCriticPolicy, + HybridActorCriticCnnPolicy, + HybridMultiInputActorCriticPolicy, +) + +MlpPolicy = HybridActorCriticPolicy +CnnPolicy = HybridActorCriticCnnPolicy +MultiInputPolicy = HybridMultiInputActorCriticPolicy diff --git a/sb3_contrib/ppo_hybrid/ppo_hybrid.py b/sb3_contrib/ppo_hybrid/ppo_hybrid.py new file mode 100644 index 00000000..75416a98 --- /dev/null +++ b/sb3_contrib/ppo_hybrid/ppo_hybrid.py @@ -0,0 +1,438 @@ +from typing import Any, ClassVar, Optional, TypeVar, Union +import warnings +import torch as th +from torch.nn import functional as F +import numpy as np +from gymnasium import spaces +from sb3_contrib.common.hybrid.policies import HybridActorCriticPolicy, HybridActorCriticCnnPolicy, HybridMultiInputActorCriticPolicy +from 
stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm +from stable_baselines3.common.policies import BasePolicy +from stable_baselines3.common.vec_env import VecEnv +from stable_baselines3.common.callbacks import BaseCallback +from stable_baselines3.common.type_aliases import MaybeCallback, GymEnv, Schedule +from stable_baselines3.common.utils import obs_as_tensor +from sb3_contrib.ppo_hybrid.buffers import HybridActionsRolloutBuffer +from stable_baselines3.common.utils import explained_variance +from stable_baselines3.common.utils import FloatSchedule + +SelfHybridPPO = TypeVar("SelfHybridPPO", bound="HybridPPO") + + +class HybridPPO(OnPolicyAlgorithm): + policy_aliases: ClassVar[dict[str, type[BasePolicy]]] = { + "MlpPolicy": HybridActorCriticPolicy, + "CnnPolicy": HybridActorCriticCnnPolicy, + "MultiInputPolicy": HybridMultiInputActorCriticPolicy, + } + + rollout_buffer: HybridActionsRolloutBuffer + policy: HybridActorCriticPolicy + + def __init__( + self, + policy: Union[str, type[HybridActorCriticPolicy]], + env: Union[GymEnv, str], # TODO: check if custom env needed to accept multiple actions + learning_rate: Union[float, Schedule] = 3e-4, + n_steps: int = 2048, + batch_size: int = 64, + n_epochs: int = 10, + gamma: float = 0.99, + gae_lambda: float = 0.95, + clip_range: Union[float, Schedule] = 0.2, + clip_range_vf: Union[None, float, Schedule] = None, + normalize_advantage: bool = True, + ent_coef: float = 0.0, + vf_coef: float = 0.5, + max_grad_norm: float = 0.5, + use_sde: bool = False, + sde_sample_freq: int = -1, + rollout_buffer_class: Optional[type[HybridActionsRolloutBuffer]] = None, # TODO: check if custom class needed to accept multiple actions + rollout_buffer_kwargs: Optional[dict[str, Any]] = None, + target_kl: Optional[float] = None, + stats_window_size: int = 100, + tensorboard_log: Optional[str] = None, + policy_kwargs: Optional[dict[str, Any]] = None, + verbose: int = 0, + seed: Optional[int] = None, + device: Union[th.device, str] = "auto", + _init_setup_model: bool = True, + ): + self.clip_range = clip_range + self.clip_range_vf = clip_range_vf + + super().__init__( + policy=policy, + env=env, + learning_rate=learning_rate, + n_steps=n_steps, + gamma=gamma, + gae_lambda=gae_lambda, + ent_coef=ent_coef, + vf_coef=vf_coef, + max_grad_norm=max_grad_norm, + use_sde=use_sde, + sde_sample_freq=sde_sample_freq, + rollout_buffer_class=rollout_buffer_class, + rollout_buffer_kwargs=rollout_buffer_kwargs, + stats_window_size=stats_window_size, + tensorboard_log=tensorboard_log, + policy_kwargs=policy_kwargs, + verbose=verbose, + seed=seed, + device=device, + _init_setup_model=_init_setup_model, + supported_action_spaces=(spaces.Tuple,), + ) + + # Sanity check, otherwise it will lead to noisy gradient and NaN + # because of the advantage normalization + if normalize_advantage: + assert ( + batch_size > 1 + ), "`batch_size` must be greater than 1. See https://github.com/DLR-RM/stable-baselines3/issues/440" + + if self.env is not None: + # Check that `n_steps * n_envs > 1` to avoid NaN + # when doing advantage normalization + buffer_size = self.env.num_envs * self.n_steps + assert buffer_size > 1 or ( + not normalize_advantage + ), f"`n_steps * n_envs` must be greater than 1. 
Currently n_steps={self.n_steps} and n_envs={self.env.num_envs}" + # Check that the rollout buffer size is a multiple of the mini-batch size + untruncated_batches = buffer_size // batch_size + if buffer_size % batch_size > 0: + warnings.warn( + f"You have specified a mini-batch size of {batch_size}," + f" but because the `RolloutBuffer` is of size `n_steps * n_envs = {buffer_size}`," + f" after every {untruncated_batches} untruncated mini-batches," + f" there will be a truncated mini-batch of size {buffer_size % batch_size}\n" + f"We recommend using a `batch_size` that is a factor of `n_steps * n_envs`.\n" + f"Info: (n_steps={self.n_steps} and n_envs={self.env.num_envs})" + ) + + self.batch_size = batch_size + self.n_epochs = n_epochs + self.normalize_advantage = normalize_advantage + self.target_kl = target_kl + + # NOTE: _setup_model already called in super().__init__(), but PPO and MaskablePPO call it again here, so we copy that behavior + if _init_setup_model: + self._setup_model() + + def _setup_model(self) -> None: + self._setup_lr_schedule() + self.set_random_seed(self.seed) + + if self.rollout_buffer_class is None: + # TODO: maybe extend if buffers for Dict obs is implemented + self.rollout_buffer_class = HybridActionsRolloutBuffer + + self.rollout_buffer = self.rollout_buffer_class( + self.n_steps, + self.observation_space, + self.action_space, + self.device, + gamma=self.gamma, + gae_lambda=self.gae_lambda, + n_envs=self.n_envs, + **self.rollout_buffer_kwargs, + ) + + self.policy = self.policy_class( + self.observation_space, + self.action_space, + self.lr_schedule, + **self.policy_kwargs, + ) + self.policy = self.policy.to(self.device) + + if not isinstance(self.policy, HybridActorCriticPolicy): + raise ValueError("Policy must subclass HybridActorCriticPolicy") + + # Initialize schedules for policy/value clipping + self.clip_range = FloatSchedule(self.clip_range) + if self.clip_range_vf is not None: + if isinstance(self.clip_range_vf, (float, int)): + assert self.clip_range_vf > 0, "`clip_range_vf` must be positive, " "pass `None` to deactivate vf clipping" + + self.clip_range_vf = FloatSchedule(self.clip_range_vf) + + def collect_rollouts( + self, + env: VecEnv, + callback: BaseCallback, + rollout_buffer: HybridActionsRolloutBuffer, + n_rollout_steps: int, + ) -> bool: + """ + Collect experiences using the current policy and fill a ``HybridActionsRolloutBuffer``. + The term rollout here refers to the model-free notion and should not + be used with the concept of rollout used in model-based RL or planning. + + :param env: The training environment + :param callback: Callback that will be called at each step + (and at the beginning and end of the rollout) + :param rollout_buffer: Buffer to fill with rollouts + :param n_rollout_steps: Number of experiences to collect per environment + :return: True if function returned with at least `n_rollout_steps` + collected, False if callback terminated rollout prematurely. 
+ """ + assert self._last_obs is not None, "No previous observation was provided" + # Switch to eval mode (this affects batch norm / dropout) + self.policy.set_training_mode(False) + + n_steps = 0 + rollout_buffer.reset() + # Sample new weights for the state dependent exploration + if self.use_sde: + self.policy.reset_noise(env.num_envs) + + callback.on_rollout_start() + + while n_steps < n_rollout_steps: + if self.use_sde and self.sde_sample_freq > 0 and n_steps % self.sde_sample_freq == 0: + # Sample a new noise matrix + self.policy.reset_noise(env.num_envs) + + with th.no_grad(): + # Convert to pytorch tensor or to TensorDict + obs_tensor = obs_as_tensor(self._last_obs, self.device) + actions_d, actions_c, values, log_probs_d, log_probs_c = self.policy(obs_tensor) + actions_d = actions_d.cpu().numpy() + actions_c = actions_c.cpu().numpy() + + # Rescale and perform action + clipped_actions_c = actions_c + + # action_space is spaces.Tuple[spaces.MultiDiscrete, spaces.Box] + if self.policy.squash_output: + # Unscale the actions to match env bounds + # if they were previously squashed (scaled in [-1, 1]) + clipped_actions_c = self.policy.unscale_action(clipped_actions_c) + else: + # Otherwise, clip the actions to avoid out of bound error + # as we are sampling from an unbounded Gaussian distribution + clipped_actions_c = np.clip(actions_c, self.action_space[1].low, self.action_space[1].high) + + # concat discrete and continuous actions to make them compatible with vectorized envs + concat_actions = np.concatenate([ + actions_d.reshape(env.num_envs, -1), # TODO: check if reshape is needed + clipped_actions_c.reshape(env.num_envs, -1) # TODO: check if reshape is needed + ], axis=1) + + new_obs, rewards, dones, infos = env.step(concat_actions) + + self.num_timesteps += env.num_envs + + # Give access to local variables + callback.update_locals(locals()) + if not callback.on_step(): + return False + + self._update_info_buffer(infos, dones) + n_steps += 1 + + # Reshape in case of discrete action + if isinstance(self.action_space[0], spaces.Discrete): + # Reshape in case of discrete action + actions_d = actions_d.reshape(-1, 1) + elif isinstance(self.action_space[0], spaces.MultiDiscrete): + actions_d = actions_d.reshape(-1, self.action_space[0].nvec.shape[0]) + + # Handle timeout by bootstrapping with value function + # see GitHub issue #633 + for idx, done in enumerate(dones): + if ( + done + and infos[idx].get("terminal_observation") is not None + and infos[idx].get("TimeLimit.truncated", False) + ): + terminal_obs = self.policy.obs_to_tensor(infos[idx]["terminal_observation"])[0] + with th.no_grad(): + terminal_value = self.policy.predict_values(terminal_obs)[0] # type: ignore[arg-type] + rewards[idx] += self.gamma * terminal_value + + rollout_buffer.add( + self._last_obs, # type: ignore[arg-type] + actions_d, + clipped_actions_c, + rewards, + self._last_episode_starts, # type: ignore[arg-type] + values, + log_probs_d, + log_probs_c + ) + self._last_obs = new_obs # type: ignore[assignment] + self._last_episode_starts = dones + + with th.no_grad(): + # Compute value for the last timestep + values = self.policy.predict_values(obs_as_tensor(new_obs, self.device)) # type: ignore[arg-type] + + rollout_buffer.compute_returns_and_advantage(last_values=values, dones=dones) + + callback.update_locals(locals()) + + callback.on_rollout_end() + + return True + + def train(self) -> None: + """ + Update policy using the currently gathered rollout buffer. 
+ """ + # Switch to train mode (this affects batch norm / dropout) + self.policy.set_training_mode(True) + # Update optimizer learning rate + self._update_learning_rate(self.policy.optimizer) + # Compute current clip range + clip_range = self.clip_range(self._current_progress_remaining) # type: ignore[operator] + # Optional: clip range for the value function + if self.clip_range_vf is not None: + clip_range_vf = self.clip_range_vf(self._current_progress_remaining) # type: ignore[operator] + + entropy_losses = [] + pg_losses_d, pg_losses_c, value_losses = [], [], [] + clip_fractions_d, clip_fractions_c = [], [] + + continue_training = True + # train for n_epochs epochs + for epoch in range(self.n_epochs): + approx_kl_divs = [] + # Do a complete pass on the rollout buffer + for rollout_data in self.rollout_buffer.get(self.batch_size): + actions_d = rollout_data.actions_d + actions_c = rollout_data.actions_c + if isinstance(self.action_space[0], spaces.Discrete): + # Reshape in case of discrete action + actions_d = actions_d.reshape(-1, 1) + elif isinstance(self.action_space[0], spaces.MultiDiscrete): + actions_d = actions_d.reshape(-1, self.action_space[0].nvec.shape[0]) + + values, log_probs, entropy = self.policy.evaluate_actions(rollout_data.observations, actions_d, actions_c) + log_prob_d, log_prob_c = log_probs + entropy_d, entropy_c = entropy + values = values.flatten() + + # Normalize advantage + advantages = rollout_data.advantages + # Normalization does not make sense if mini batchsize == 1, see GH issue #325 + if self.normalize_advantage and len(advantages) > 1: + advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) + + # ratio between old and new policy, for discrete and continuous actions. + # Should be one at the first iteration + ratio_d = th.exp(log_prob_d - rollout_data.old_log_prob_d) + ratio_c = th.exp(log_prob_c - rollout_data.old_log_prob_c) + + # clipped surrogate loss for discrete actions + policy_loss_d_1 = advantages * ratio_d + policy_loss_d_2 = advantages * th.clamp(ratio_d, 1 - clip_range, 1 + clip_range) + policy_loss_d = -th.min(policy_loss_d_1, policy_loss_d_2).mean() + + # clipped surrogate loss for continuous actions + policy_loss_c_1 = advantages * ratio_c + policy_loss_c_2 = advantages * th.clamp(ratio_c, 1 - clip_range, 1 + clip_range) + policy_loss_c = -th.min(policy_loss_c_1, policy_loss_c_2).mean() + + # Logging + pg_losses_d.append(policy_loss_d.item()) + clip_fraction_d = th.mean((th.abs(ratio_d - 1) > clip_range).float()).item() + clip_fractions_d.append(clip_fraction_d) + pg_losses_c.append(policy_loss_c.item()) + clip_fraction_c = th.mean((th.abs(ratio_c - 1) > clip_range).float()).item() + clip_fractions_c.append(clip_fraction_c) + + # Value loss + if self.clip_range_vf is None: + # No clipping + values_pred = values + else: + # Clip the difference between old and new value + # NOTE: this depends on the reward scaling + values_pred = rollout_data.old_values + th.clamp( + values - rollout_data.old_values, -clip_range_vf, clip_range_vf + ) + # Value loss using the TD(gae_lambda) target + value_loss = F.mse_loss(rollout_data.returns, values_pred) + value_losses.append(value_loss.item()) + + # Entropy loss favor exploration + if entropy is None: + # Approximate entropy when no analytical form + entropy_loss_d = -th.mean(-log_prob_d) + entropy_loss_c = -th.mean(-log_prob_c) + entropy_loss = entropy_loss_d + entropy_loss_c + else: + entropy_loss = -th.mean(entropy_d) - th.mean(entropy_c) + entropy_losses.append(entropy_loss.item()) + 
+ # total loss function + loss = 0.5 * (policy_loss_d + policy_loss_c) + self.ent_coef * entropy_loss + self.vf_coef * value_loss + + # Calculate approximate form of reverse KL Divergence for early stopping + # see issue #417: https://github.com/DLR-RM/stable-baselines3/issues/417 + # and discussion in PR #419: https://github.com/DLR-RM/stable-baselines3/pull/419 + # and Schulman blog: http://joschu.net/blog/kl-approx.html + # Note: using the max KL divergence between discrete and continuous actions to stay conservative + with th.no_grad(): + log_ratio_d = log_prob_d - rollout_data.old_log_prob_d + log_ratio_c = log_prob_c - rollout_data.old_log_prob_c + approx_kl_div_d = th.mean((th.exp(log_ratio_d) - 1) - log_ratio_d).cpu().numpy() + approx_kl_div_c = th.mean((th.exp(log_ratio_c) - 1) - log_ratio_c).cpu().numpy() + approx_kl_div = max(approx_kl_div_d, approx_kl_div_c) + approx_kl_divs.append(approx_kl_div) + + if self.target_kl is not None and approx_kl_div > 1.5 * self.target_kl: + continue_training = False + if self.verbose >= 1: + print(f"Early stopping at step {epoch} due to reaching max kl: {approx_kl_div:.2f}") + break + + # Optimization step + self.policy.optimizer.zero_grad() + loss.backward() + # Clip grad norm + th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm) + self.policy.optimizer.step() + + self._n_updates += 1 + if not continue_training: + break + + explained_var = explained_variance(self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten()) + + # Logs + self.logger.record("train/entropy_loss", np.mean(entropy_losses)) + self.logger.record("train/policy_gradient_discrete_loss", np.mean(pg_losses_d)) + self.logger.record("train/policy_gradient_continuous_loss", np.mean(pg_losses_c)) + self.logger.record("train/value_loss", np.mean(value_losses)) + self.logger.record("train/approx_kl", np.mean(approx_kl_divs)) + self.logger.record("train/clip_fraction_discrete", np.mean(clip_fractions_d)) + self.logger.record("train/clip_fraction_continuous", np.mean(clip_fractions_c)) + self.logger.record("train/loss", loss.item()) + self.logger.record("train/explained_variance", explained_var) + if hasattr(self.policy, "log_std"): + self.logger.record("train/std", th.exp(self.policy.log_std).mean().item()) + self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard") + self.logger.record("train/clip_range", clip_range) + if self.clip_range_vf is not None: + self.logger.record("train/clip_range_vf", clip_range_vf) + + def learn( + self: SelfHybridPPO, + total_timesteps: int, + callback: MaybeCallback = None, + log_interval: int = 1, + tb_log_name: str = "Hybrid PPO", + reset_num_timesteps: bool = True, + progress_bar: bool = False, + ) -> SelfHybridPPO: + return super().learn( + total_timesteps=total_timesteps, + callback=callback, + log_interval=log_interval, + tb_log_name=tb_log_name, + reset_num_timesteps=reset_num_timesteps, + progress_bar=progress_bar, + )
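Usage sketch (not part of the diff, untested): a minimal example of how the pieces added above are intended to fit together; the hyperparameter values are placeholders.

    from sb3_contrib import HybridPPO
    from sb3_contrib.common.envs import CatchingPointEnv

    # CatchingPointEnv exposes the Tuple(MultiDiscrete, Box) action space that
    # HybridPPO's policy and rollout buffer expect.
    env = CatchingPointEnv()
    model = HybridPPO("MlpPolicy", env, n_steps=256, batch_size=64, verbose=1)
    model.learn(total_timesteps=10_000)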