diff --git a/brax/training/agents/ppo/networks.py b/brax/training/agents/ppo/networks.py
index 1410831f..70c5a1ca 100644
--- a/brax/training/agents/ppo/networks.py
+++ b/brax/training/agents/ppo/networks.py
@@ -14,6 +14,7 @@
 
 """PPO networks."""
 
+import warnings
 from typing import Any, Literal, Mapping, Sequence, Tuple
 
 from brax.training import distribution
@@ -116,6 +117,23 @@ def make_ppo_networks(
   value_kernel_init_kwargs = value_network_kernel_init_kwargs or {}
   mean_kernel_init_kwargs_ = mean_kernel_init_kwargs or {}
 
+  if distribution_type == 'tanh_normal':
+    ignored = []
+    if init_noise_std != 1.0:
+      ignored.append(f'init_noise_std={init_noise_std!r}')
+    if noise_std_type != 'scalar':
+      ignored.append(f'noise_std_type={noise_std_type!r}')
+    if state_dependent_std:
+      ignored.append(f'state_dependent_std={state_dependent_std!r}')
+    if ignored:
+      warnings.warn(
+          f'{", ".join(ignored)} {"has" if len(ignored) == 1 else "have"}'
+          ' no effect with distribution_type="tanh_normal". The standard'
+          ' deviation is determined entirely by the policy network output.'
+          ' These parameters only apply to distribution_type="normal".',
+          stacklevel=2,
+      )
+
   parametric_action_distribution: distribution.ParametricDistribution
   if distribution_type == 'normal':
     parametric_action_distribution = distribution.NormalDistribution(
@@ -163,4 +181,4 @@
       policy_network=policy_network,
       value_network=value_network,
       parametric_action_distribution=parametric_action_distribution,
-  )
+  )
\ No newline at end of file