Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/guide/custom_policy.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ other type of input features (MlpPolicies) and multiple different inputs (MultiI
For A2C and PPO, continuous actions are clipped during training and testing
(to avoid out of bound error). SAC, DDPG and TD3 squash the action, using a `tanh()` transformation,
which handles bounds more correctly.

For A2C and PPO with normalized continuous action spaces, `policy_kwargs=dict(squash_mean_actions=True)` can be used
to squash the Gaussian distribution mean to `[-1, 1]`.
:::

## SB3 Policy
Expand Down
2 changes: 2 additions & 0 deletions docs/misc/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@

### New Features:

- Added `squash_mean_actions` policy option for A2C/PPO to tanh-squash the mean of `DiagGaussianDistribution`

### Bug Fixes:
- Fixed deprecated error Taxi-v3 from gymnasium v1.3.0 in tests

Expand Down
8 changes: 6 additions & 2 deletions stable_baselines3/common/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,17 +137,21 @@ def __init__(self, action_dim: int):
super().__init__()
self.action_dim = action_dim

def proba_distribution_net(self, latent_dim: int, log_std_init: float = 0.0) -> tuple[nn.Module, nn.Parameter]:
def proba_distribution_net(
self, latent_dim: int, log_std_init: float = 0.0, squash_mean_actions: bool = False
) -> tuple[nn.Module, nn.Parameter]:
"""
Create the layers and parameter that represent the distribution:
one output will be the mean of the Gaussian, the other parameter will be the
standard deviation (log std in fact to allow negative values)

:param latent_dim: Dimension of the last layer of the policy (before the action layer)
:param log_std_init: Initial value for the log standard deviation
:param squash_mean_actions: Whether to squash the mean actions using a tanh function.
:return:
"""
mean_actions = nn.Linear(latent_dim, self.action_dim)
mean_actions_net = nn.Linear(latent_dim, self.action_dim)
mean_actions = nn.Sequential(mean_actions_net, nn.Tanh()) if squash_mean_actions else mean_actions_net
# TODO: allow action dependent std
log_std = nn.Parameter(th.ones(self.action_dim) * log_std_init, requires_grad=True)
return mean_actions, log_std
Expand Down
21 changes: 20 additions & 1 deletion stable_baselines3/common/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,8 @@ class ActorCriticPolicy(BasePolicy):
above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
:param squash_output: Whether to squash the output using a tanh function,
this allows to ensure boundaries when using gSDE.
:param squash_mean_actions: Whether to squash the mean actions using a tanh function.
This is only available when using ``DiagGaussianDistribution``.
:param features_extractor_class: Features extractor to use.
:param features_extractor_kwargs: Keyword arguments
to pass to the features extractor.
Expand Down Expand Up @@ -464,6 +466,7 @@ def __init__(
normalize_images: bool = True,
optimizer_class: type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: dict[str, Any] | None = None,
squash_mean_actions: bool = False,
):
if optimizer_kwargs is None:
optimizer_kwargs = {}
Expand Down Expand Up @@ -514,9 +517,16 @@ def __init__(
self.vf_features_extractor = self.make_features_extractor()

self.log_std_init = log_std_init
self.squash_mean_actions = squash_mean_actions
dist_kwargs = None

assert not (squash_output and not use_sde), "squash_output=True is only available when using gSDE (use_sde=True)"
assert not (
squash_mean_actions and use_sde
), "squash_mean_actions=True is only available without gSDE (use_sde=False). Use squash_output=True when using gSDE."
assert not (
squash_mean_actions and not isinstance(action_space, spaces.Box)
), "squash_mean_actions=True is only available for Box action spaces"
# Keyword arguments for gSDE distribution
if use_sde:
dist_kwargs = {
Expand Down Expand Up @@ -546,6 +556,7 @@ def _get_constructor_parameters(self) -> dict[str, Any]:
use_sde=self.use_sde,
log_std_init=self.log_std_init,
squash_output=default_none_kwargs["squash_output"],
squash_mean_actions=self.squash_mean_actions,
full_std=default_none_kwargs["full_std"],
use_expln=default_none_kwargs["use_expln"],
lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone
Expand Down Expand Up @@ -595,7 +606,7 @@ def _build(self, lr_schedule: Schedule) -> None:

if isinstance(self.action_dist, DiagGaussianDistribution):
self.action_net, self.log_std = self.action_dist.proba_distribution_net(
latent_dim=latent_dim_pi, log_std_init=self.log_std_init
latent_dim=latent_dim_pi, log_std_init=self.log_std_init, squash_mean_actions=self.squash_mean_actions
)
elif isinstance(self.action_dist, StateDependentNoiseDistribution):
self.action_net, self.log_std = self.action_dist.proba_distribution_net(
Expand Down Expand Up @@ -783,6 +794,8 @@ class ActorCriticCnnPolicy(ActorCriticPolicy):
above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
:param squash_output: Whether to squash the output using a tanh function,
this allows to ensure boundaries when using gSDE.
:param squash_mean_actions: Whether to squash the mean actions using a tanh function.
This is only available when using ``DiagGaussianDistribution``.
:param features_extractor_class: Features extractor to use.
:param features_extractor_kwargs: Keyword arguments
to pass to the features extractor.
Expand Down Expand Up @@ -814,6 +827,7 @@ def __init__(
normalize_images: bool = True,
optimizer_class: type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: dict[str, Any] | None = None,
squash_mean_actions: bool = False,
):
super().__init__(
observation_space,
Expand All @@ -833,6 +847,7 @@ def __init__(
normalize_images,
optimizer_class,
optimizer_kwargs,
squash_mean_actions=squash_mean_actions,
)


Expand All @@ -856,6 +871,8 @@ class MultiInputActorCriticPolicy(ActorCriticPolicy):
above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
:param squash_output: Whether to squash the output using a tanh function,
this allows to ensure boundaries when using gSDE.
:param squash_mean_actions: Whether to squash the mean actions using a tanh function.
This is only available when using ``DiagGaussianDistribution``.
:param features_extractor_class: Uses the CombinedExtractor
:param features_extractor_kwargs: Keyword arguments
to pass to the features extractor.
Expand Down Expand Up @@ -887,6 +904,7 @@ def __init__(
normalize_images: bool = True,
optimizer_class: type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: dict[str, Any] | None = None,
squash_mean_actions: bool = False,
):
super().__init__(
observation_space,
Expand All @@ -906,6 +924,7 @@ def __init__(
normalize_images,
optimizer_class,
optimizer_kwargs,
squash_mean_actions=squash_mean_actions,
)


Expand Down
39 changes: 39 additions & 0 deletions tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,45 @@ def test_squashed_gaussian(model_class):
assert th.max(th.abs(actions)) <= 1.0


def test_squashed_mean_diag_gaussian():
dist = DiagGaussianDistribution(N_ACTIONS)
mean_actions_net, log_std = dist.proba_distribution_net(N_FEATURES, squash_mean_actions=True)

assert isinstance(mean_actions_net, th.nn.Sequential)
mean_actions = mean_actions_net(th.ones(4, N_FEATURES) * 100.0)
assert th.max(th.abs(mean_actions)) <= 1.0

dist = dist.proba_distribution(mean_actions, log_std)
assert th.max(th.abs(dist.mode())) <= 1.0


@pytest.mark.parametrize("model_class", [A2C, PPO])
def test_squashed_mean_actions_policy(model_class, tmp_path):
kwargs = {"n_steps": 64}
if model_class == PPO:
kwargs["n_epochs"] = 1
model = model_class(
"MlpPolicy",
"Pendulum-v1",
policy_kwargs=dict(net_arch=[], squash_mean_actions=True),
**kwargs,
)

assert model.policy.squash_mean_actions
assert isinstance(model.policy.action_net, th.nn.Sequential)

random_obs = np.array([model.observation_space.sample() for _ in range(4)])
obs_tensor, _ = model.policy.obs_to_tensor(random_obs)
distribution = model.policy.get_distribution(obs_tensor)
assert isinstance(distribution, DiagGaussianDistribution)
assert th.max(th.abs(distribution.distribution.mean)) <= 1.0

model.save(tmp_path / "squashed_mean_actions_model.zip")
loaded_model = model_class.load(tmp_path / "squashed_mean_actions_model.zip")
assert loaded_model.policy.squash_mean_actions
assert isinstance(loaded_model.policy.action_net, th.nn.Sequential)


@pytest.fixture()
def dummy_model_distribution_obs_and_actions() -> tuple[A2C, np.ndarray, np.ndarray]:
"""
Expand Down
8 changes: 8 additions & 0 deletions tests/test_sde.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,14 @@ def test_only_sde_squashed():
PPO("MlpPolicy", "Pendulum-v1", use_sde=False, policy_kwargs=dict(squash_output=True))


def test_squashed_mean_actions_requires_diag_gaussian():
with pytest.raises(AssertionError, match=r"without gSDE"):
PPO("MlpPolicy", "Pendulum-v1", use_sde=True, policy_kwargs=dict(squash_mean_actions=True))

with pytest.raises(AssertionError, match=r"Box action spaces"):
PPO("MlpPolicy", "CartPole-v1", policy_kwargs=dict(squash_mean_actions=True))


@pytest.mark.parametrize("model_class", [SAC, A2C, PPO])
@pytest.mark.parametrize("use_expln", [False, True])
@pytest.mark.parametrize("squash_output", [False, True])
Expand Down