Add opt-in tanh squashing for DiagGaussianDistribution mean actions #2249

Draft
cgliner wants to merge 1 commit into DLR-RM:master from cgliner:feat/squash-mean-actions

cgliner commented May 5, 2026

Description

This PR adds an opt-in squash_mean_actions policy option for A2C/PPO policies using DiagGaussianDistribution.

When enabled with:

policy_kwargs=dict(squash_mean_actions=True)

the Gaussian mean action network is wrapped with nn.Tanh(), constraining the mean actions to [-1, 1].

The default behavior is unchanged. The option is only available for non-gSDE policies with Box action spaces. For gSDE, SB3 already provides the squash_output=True path.

This PR also adds:

  • tests for the squashed mean action layer,
  • A2C/PPO integration and save/load coverage,
  • validation tests for unsupported usage,
  • documentation,
  • changelog entry.

I used OpenAI Codex to help implement the change, add tests, and run local verification.

Motivation and Context

For continuous Box action spaces, especially normalized spaces like [-1, 1], the current diagonal Gaussian policy can produce unbounded mean actions. Those actions are then clipped to the action space bounds.

That clipping can lead to poor behavior near action-space edges, because the policy may learn means far outside the valid range while the environment only sees clipped boundary actions. This PR provides an opt-in way to keep the deterministic Gaussian mean inside the normalized action range with a smooth tanh transformation.
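The clipping problem described above can be illustrated with a small standalone sketch. It is not code from this PR or from SB3: the raw values are made up, and `clip` is a hypothetical helper standing in for the clipping applied before the environment step.

```python
import math


def clip(x, low=-1.0, high=1.0):
    """Clamp x to [low, high] (hypothetical stand-in for action clipping)."""
    return max(low, min(high, x))


# Hypothetical raw network outputs for a 1-D action (illustrative values only).
raw_outputs = [1.5, 3.0, 10.0]

# Used directly as unbounded means, then clipped: all collapse onto the boundary.
clipped_means = [clip(x) for x in raw_outputs]

# Passed through tanh instead: still distinct, and strictly inside the bounds.
squashed_means = [math.tanh(x) for x in raw_outputs]

print(clipped_means)
print(squashed_means)
```

With clipping, three very different means produce the same boundary action, so the environment cannot distinguish them. The tanh mapping keeps them ordered and inside the range.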

This does not fully replace SAC-style squashed Gaussian distributions, but it gives A2C/PPO users a simple bounded-mean option while preserving backward compatibility.

  • I have raised an issue to propose this change (required for new features and bug fixes)

Types of changes

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)
  • Documentation (update in the documentation)

Checklist

  • I've read the CONTRIBUTION guide (required)
  • I have updated the changelog accordingly (docs/misc/changelog.md) (required).
  • My change requires a change to the documentation.
  • I have updated the tests accordingly (required for a bug fix or a new feature).
  • I have updated the documentation accordingly.
  • I have opened an associated PR on the SB3-Contrib repository (if necessary)
  • I have opened an associated PR on the RL-Zoo3 repository (if necessary)
  • I have reformatted the code using make format (required)
  • I have checked the codestyle using make check-codestyle and make lint (required)
  • I have ensured make pytest and make type both pass. (required)
  • I have checked that the documentation builds using make doc (required)

SquashedDiagGaussianDistribution squashes the sampled action distribution. Your change squashes only the Gaussian mean.

Concretely:

SquashedDiagGaussianDistribution does this:

gaussian_action ~ Normal(mean, std)
action = tanh(gaussian_action)

So both stochastic samples and deterministic actions are bounded in [-1, 1]. Because the distribution itself is transformed by tanh, it also needs a log-probability correction using the change-of-variables term:

log_prob -= sum(log(1 - action^2 + epsilon))  # summed over action dimensions

That is what SAC uses, where the policy really is a tanh-transformed Gaussian.
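As a rough illustration of that correction, here is a minimal pure-Python sketch of the log-probability of a tanh-squashed diagonal Gaussian. The function names and the `EPSILON` value are assumptions made for this example; this is not the SB3 implementation.

```python
import math

EPSILON = 1e-6  # assumed small constant to keep log() away from zero


def gaussian_log_prob(x, mean, std):
    """Log-density of a 1-D Normal(mean, std) evaluated at x."""
    return -0.5 * ((x - mean) / std) ** 2 - math.log(std) - 0.5 * math.log(2 * math.pi)


def squashed_log_prob(gaussian_actions, means, stds):
    """Log-prob of action = tanh(gaussian_action), summed over action dimensions."""
    total = 0.0
    for u, m, s in zip(gaussian_actions, means, stds):
        a = math.tanh(u)
        # Change-of-variables term: d tanh(u) / du = 1 - tanh(u)^2
        total += gaussian_log_prob(u, m, s) - math.log(1.0 - a * a + EPSILON)
    return total


u = [0.5, -1.2]      # pre-tanh Gaussian samples (illustrative)
mu = [0.0, 0.0]
sigma = [1.0, 1.0]
print(squashed_log_prob(u, mu, sigma))
```

Because 1 - action^2 is at most 1, the correction never lowers the log-probability relative to the plain Gaussian term: squashing compresses probability mass into a smaller interval.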

Your squash_mean_actions=True option does this instead:

mean = tanh(linear(latent))
action ~ Normal(mean, std)

So the center of the Gaussian is bounded in [-1, 1], but stochastic samples can still fall outside the action bounds and will still be clipped by A2C/PPO as before. The probability distribution remains an ordinary, untransformed Gaussian, so no tanh log-prob correction is needed.
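The behavior described above can be sketched in plain Python. `sample_with_squashed_mean` is a hypothetical helper written for this comment, not the PR's code, and it assumes the log-probability is evaluated on the unclipped sample while only the clipped value is sent to the environment.

```python
import math
import random


def sample_with_squashed_mean(pre_activation, std, rng, low=-1.0, high=1.0):
    """Return (mean, raw_action, clipped_action, log_prob) for a 1-D action."""
    mean = math.tanh(pre_activation)           # bounded Gaussian center
    raw_action = rng.gauss(mean, std)          # ordinary Gaussian sample, may leave bounds
    clipped_action = max(low, min(high, raw_action))  # what the environment would see
    # Plain Gaussian log-density of the unclipped sample; no tanh correction.
    log_prob = (
        -0.5 * ((raw_action - mean) / std) ** 2
        - math.log(std)
        - 0.5 * math.log(2 * math.pi)
    )
    return mean, raw_action, clipped_action, log_prob


rng = random.Random(0)
# Illustrative values: a large pre-activation puts the mean close to the upper bound.
samples = [sample_with_squashed_mean(3.0, 0.5, rng) for _ in range(1000)]
n_outside = sum(1 for _, raw, _, _ in samples if abs(raw) > 1.0)
print(f"{n_outside} of {len(samples)} raw samples fall outside [-1, 1]")
```

With the mean near the boundary, a large share of raw samples still lands outside the bounds, which is why clipping remains necessary for stochastic actions.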

Practical difference:

  • SquashedDiagGaussianDistribution: bounded actions, transformed distribution, corrected log-probs, no analytical entropy.
  • squash_mean_actions=True: bounded deterministic mean, ordinary Gaussian samples, ordinary log-probs and entropy, still may require clipping sampled actions.

So your change is a lighter, more conservative option for A2C/PPO. It improves deterministic/mean behavior near Box edges without changing the underlying probability distribution into a SAC-style squashed Gaussian.
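On the entropy point in the comparison above: the entropy of a diagonal Gaussian has a closed form that depends only on the standard deviations, so bounding the mean leaves it unchanged. A minimal sketch, where `gaussian_entropy` is a name chosen for this example:

```python
import math


def gaussian_entropy(stds):
    """Entropy of a diagonal Gaussian, summed over dimensions.

    Per dimension: 0.5 + 0.5 * log(2 * pi) + log(std). The mean does not appear.
    """
    return sum(0.5 + 0.5 * math.log(2 * math.pi) + math.log(s) for s in stds)


print(gaussian_entropy([1.0]))        # unit-variance, 1-D
print(gaussian_entropy([0.5, 0.5]))   # two dimensions with std = 0.5
```

A tanh-transformed Gaussian has no comparably simple closed form, which is why the squashed distribution lacks an analytical entropy.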
