diff --git a/docs/guide/custom_policy.md b/docs/guide/custom_policy.md index 82c4616cc..13d7f0aad 100644 --- a/docs/guide/custom_policy.md +++ b/docs/guide/custom_policy.md @@ -9,6 +9,9 @@ other type of input features (MlpPolicies) and multiple different inputs (MultiI For A2C and PPO, continuous actions are clipped during training and testing (to avoid out of bound error). SAC, DDPG and TD3 squash the action, using a `tanh()` transformation, which handles bounds more correctly. + +For A2C and PPO with normalized continuous action spaces, `policy_kwargs=dict(squash_mean_actions=True)` can be used +to squash the Gaussian distribution mean to `[-1, 1]`. ::: ## SB3 Policy diff --git a/docs/misc/changelog.md b/docs/misc/changelog.md index d58c42090..a541e59b9 100644 --- a/docs/misc/changelog.md +++ b/docs/misc/changelog.md @@ -11,6 +11,8 @@ ### New Features: +- Added `squash_mean_actions` policy option for A2C/PPO to tanh-squash the mean of `DiagGaussianDistribution` + ### Bug Fixes: - Fixed deprecated error Taxi-v3 from gymnasium v1.3.0 in tests diff --git a/stable_baselines3/common/distributions.py b/stable_baselines3/common/distributions.py index 67ffe7f5d..2f2be47b8 100644 --- a/stable_baselines3/common/distributions.py +++ b/stable_baselines3/common/distributions.py @@ -137,7 +137,9 @@ def __init__(self, action_dim: int): super().__init__() self.action_dim = action_dim - def proba_distribution_net(self, latent_dim: int, log_std_init: float = 0.0) -> tuple[nn.Module, nn.Parameter]: + def proba_distribution_net( + self, latent_dim: int, log_std_init: float = 0.0, squash_mean_actions: bool = False + ) -> tuple[nn.Module, nn.Parameter]: """ Create the layers and parameter that represent the distribution: one output will be the mean of the Gaussian, the other parameter will be the @@ -145,9 +147,11 @@ def proba_distribution_net(self, latent_dim: int, log_std_init: float = 0.0) -> :param latent_dim: Dimension of the last layer of the policy (before the action layer) :param log_std_init: Initial value for the log standard deviation + :param squash_mean_actions: Whether to squash the mean actions using a tanh function. :return: """ - mean_actions = nn.Linear(latent_dim, self.action_dim) + mean_actions_net = nn.Linear(latent_dim, self.action_dim) + mean_actions = nn.Sequential(mean_actions_net, nn.Tanh()) if squash_mean_actions else mean_actions_net # TODO: allow action dependent std log_std = nn.Parameter(th.ones(self.action_dim) * log_std_init, requires_grad=True) return mean_actions, log_std diff --git a/stable_baselines3/common/policies.py b/stable_baselines3/common/policies.py index 0f75c5327..b99115a68 100644 --- a/stable_baselines3/common/policies.py +++ b/stable_baselines3/common/policies.py @@ -433,6 +433,8 @@ class ActorCriticPolicy(BasePolicy): above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough. :param squash_output: Whether to squash the output using a tanh function, this allows to ensure boundaries when using gSDE. + :param squash_mean_actions: Whether to squash the mean actions using a tanh function. + This is only available when using ``DiagGaussianDistribution``. :param features_extractor_class: Features extractor to use. :param features_extractor_kwargs: Keyword arguments to pass to the features extractor. @@ -464,6 +466,7 @@ def __init__( normalize_images: bool = True, optimizer_class: type[th.optim.Optimizer] = th.optim.Adam, optimizer_kwargs: dict[str, Any] | None = None, + squash_mean_actions: bool = False, ): if optimizer_kwargs is None: optimizer_kwargs = {} @@ -514,9 +517,16 @@ def __init__( self.vf_features_extractor = self.make_features_extractor() self.log_std_init = log_std_init + self.squash_mean_actions = squash_mean_actions dist_kwargs = None assert not (squash_output and not use_sde), "squash_output=True is only available when using gSDE (use_sde=True)" + assert not ( + squash_mean_actions and use_sde + ), "squash_mean_actions=True is only available without gSDE (use_sde=False). Use squash_output=True when using gSDE." + assert not ( + squash_mean_actions and not isinstance(action_space, spaces.Box) + ), "squash_mean_actions=True is only available for Box action spaces" # Keyword arguments for gSDE distribution if use_sde: dist_kwargs = { @@ -546,6 +556,7 @@ def _get_constructor_parameters(self) -> dict[str, Any]: use_sde=self.use_sde, log_std_init=self.log_std_init, squash_output=default_none_kwargs["squash_output"], + squash_mean_actions=self.squash_mean_actions, full_std=default_none_kwargs["full_std"], use_expln=default_none_kwargs["use_expln"], lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone @@ -595,7 +606,7 @@ def _build(self, lr_schedule: Schedule) -> None: if isinstance(self.action_dist, DiagGaussianDistribution): self.action_net, self.log_std = self.action_dist.proba_distribution_net( - latent_dim=latent_dim_pi, log_std_init=self.log_std_init + latent_dim=latent_dim_pi, log_std_init=self.log_std_init, squash_mean_actions=self.squash_mean_actions ) elif isinstance(self.action_dist, StateDependentNoiseDistribution): self.action_net, self.log_std = self.action_dist.proba_distribution_net( @@ -783,6 +794,8 @@ class ActorCriticCnnPolicy(ActorCriticPolicy): above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough. :param squash_output: Whether to squash the output using a tanh function, this allows to ensure boundaries when using gSDE. + :param squash_mean_actions: Whether to squash the mean actions using a tanh function. + This is only available when using ``DiagGaussianDistribution``. :param features_extractor_class: Features extractor to use. :param features_extractor_kwargs: Keyword arguments to pass to the features extractor. @@ -814,6 +827,7 @@ def __init__( normalize_images: bool = True, optimizer_class: type[th.optim.Optimizer] = th.optim.Adam, optimizer_kwargs: dict[str, Any] | None = None, + squash_mean_actions: bool = False, ): super().__init__( observation_space, @@ -833,6 +847,7 @@ def __init__( normalize_images, optimizer_class, optimizer_kwargs, + squash_mean_actions=squash_mean_actions, ) @@ -856,6 +871,8 @@ class MultiInputActorCriticPolicy(ActorCriticPolicy): above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough. :param squash_output: Whether to squash the output using a tanh function, this allows to ensure boundaries when using gSDE. + :param squash_mean_actions: Whether to squash the mean actions using a tanh function. + This is only available when using ``DiagGaussianDistribution``. :param features_extractor_class: Uses the CombinedExtractor :param features_extractor_kwargs: Keyword arguments to pass to the features extractor. @@ -887,6 +904,7 @@ def __init__( normalize_images: bool = True, optimizer_class: type[th.optim.Optimizer] = th.optim.Adam, optimizer_kwargs: dict[str, Any] | None = None, + squash_mean_actions: bool = False, ): super().__init__( observation_space, @@ -906,6 +924,7 @@ def __init__( normalize_images, optimizer_class, optimizer_kwargs, + squash_mean_actions=squash_mean_actions, ) diff --git a/tests/test_distributions.py b/tests/test_distributions.py index 0f8ae6253..cb2afc79c 100644 --- a/tests/test_distributions.py +++ b/tests/test_distributions.py @@ -53,6 +53,45 @@ def test_squashed_gaussian(model_class): assert th.max(th.abs(actions)) <= 1.0 +def test_squashed_mean_diag_gaussian(): + dist = DiagGaussianDistribution(N_ACTIONS) + mean_actions_net, log_std = dist.proba_distribution_net(N_FEATURES, squash_mean_actions=True) + + assert isinstance(mean_actions_net, th.nn.Sequential) + mean_actions = mean_actions_net(th.ones(4, N_FEATURES) * 100.0) + assert th.max(th.abs(mean_actions)) <= 1.0 + + dist = dist.proba_distribution(mean_actions, log_std) + assert th.max(th.abs(dist.mode())) <= 1.0 + + +@pytest.mark.parametrize("model_class", [A2C, PPO]) +def test_squashed_mean_actions_policy(model_class, tmp_path): + kwargs = {"n_steps": 64} + if model_class == PPO: + kwargs["n_epochs"] = 1 + model = model_class( + "MlpPolicy", + "Pendulum-v1", + policy_kwargs=dict(net_arch=[], squash_mean_actions=True), + **kwargs, + ) + + assert model.policy.squash_mean_actions + assert isinstance(model.policy.action_net, th.nn.Sequential) + + random_obs = np.array([model.observation_space.sample() for _ in range(4)]) + obs_tensor, _ = model.policy.obs_to_tensor(random_obs) + distribution = model.policy.get_distribution(obs_tensor) + assert isinstance(distribution, DiagGaussianDistribution) + assert th.max(th.abs(distribution.distribution.mean)) <= 1.0 + + model.save(tmp_path / "squashed_mean_actions_model.zip") + loaded_model = model_class.load(tmp_path / "squashed_mean_actions_model.zip") + assert loaded_model.policy.squash_mean_actions + assert isinstance(loaded_model.policy.action_net, th.nn.Sequential) + + @pytest.fixture() def dummy_model_distribution_obs_and_actions() -> tuple[A2C, np.ndarray, np.ndarray]: """ diff --git a/tests/test_sde.py b/tests/test_sde.py index 8aabfc094..6a73e4c31 100644 --- a/tests/test_sde.py +++ b/tests/test_sde.py @@ -66,6 +66,14 @@ def test_only_sde_squashed(): PPO("MlpPolicy", "Pendulum-v1", use_sde=False, policy_kwargs=dict(squash_output=True)) +def test_squashed_mean_actions_requires_diag_gaussian(): + with pytest.raises(AssertionError, match=r"without gSDE"): + PPO("MlpPolicy", "Pendulum-v1", use_sde=True, policy_kwargs=dict(squash_mean_actions=True)) + + with pytest.raises(AssertionError, match=r"Box action spaces"): + PPO("MlpPolicy", "CartPole-v1", policy_kwargs=dict(squash_mean_actions=True)) + + @pytest.mark.parametrize("model_class", [SAC, A2C, PPO]) @pytest.mark.parametrize("use_expln", [False, True]) @pytest.mark.parametrize("squash_output", [False, True])