diff --git a/test/test_exploration.py b/test/test_exploration.py index b9ddda702d9..7968b720164 100644 --- a/test/test_exploration.py +++ b/test/test_exploration.py @@ -185,7 +185,7 @@ def test_no_spec_error(self): with pytest.raises( RuntimeError, - match="Failed while executing module|spec must be provided to the exploration wrapper", + match="Failed while executing module|spec has not been set", ): explorative_policy(td) @@ -421,8 +421,12 @@ def test_nested( return def test_no_spec_error(self, device): - with pytest.raises(RuntimeError, match="spec cannot be None."): - OrnsteinUhlenbeckProcessModule(spec=None).to(device) + module = OrnsteinUhlenbeckProcessModule(spec=None, safe=False).to(device) + td = TensorDict( + {"action": torch.randn(3, device=device)}, batch_size=[3], device=device + ) + out = module(td) + assert "action" in out.keys() @pytest.mark.parametrize("device", get_default_devices()) @@ -683,7 +687,7 @@ def test_set_exploration_modules_spec_from_env(device, use_batched_env): d_obs = env.observation_spec["observation"].shape[-1] d_act = expected_spec.shape[-1] - # Create a policy with exploration module that has spec=None + # Create a policy with exploration modules that have spec=None net = nn.Sequential( nn.Linear(d_obs, 2 * d_act, device=device), NormalParamExtractor() ) @@ -699,19 +703,29 @@ def test_set_exploration_modules_spec_from_env(device, use_batched_env): distribution_class=TanhNormal, default_interaction_type=InteractionType.RANDOM, ).to(device) - exploration_module = AdditiveGaussianModule(spec=None, device=device) - exploratory_policy = TensorDictSequential(policy, exploration_module) + additive = AdditiveGaussianModule(spec=None, device=device) + egreedy = EGreedyModule(spec=None, device=device) + ou = OrnsteinUhlenbeckProcessModule(spec=None, device=device) + exploratory_policy = TensorDictSequential(policy, additive, egreedy, ou) - assert exploration_module._spec is None + assert additive.spec is None + assert egreedy.spec is None + assert ou.spec is None set_exploration_modules_spec_from_env(exploratory_policy, env) # Verify spec is set after configuration and matches the environment's action_spec - assert exploration_module._spec is not None - if isinstance(exploration_module._spec, Composite): - assert exploration_module._spec[exploration_module.action_key] == expected_spec - else: - assert exploration_module._spec == expected_spec + for exploration_module in (additive, egreedy, ou): + assert exploration_module.spec is not None + if isinstance(exploration_module.spec, Composite): + action_key = ( + exploration_module.action_key + if hasattr(exploration_module, "action_key") + else exploration_module.ou.key + ) + assert exploration_module.spec[action_key] == expected_spec + else: + assert exploration_module.spec == expected_spec td = env.reset() result = exploratory_policy(td) diff --git a/torchrl/modules/tensordict_module/exploration.py b/torchrl/modules/tensordict_module/exploration.py index e6638b32296..0bdbdfa2d86 100644 --- a/torchrl/modules/tensordict_module/exploration.py +++ b/torchrl/modules/tensordict_module/exploration.py @@ -93,7 +93,7 @@ class EGreedyModule(TensorDictModuleBase): def __init__( self, - spec: TensorSpec, + spec: TensorSpec | None, eps_init: float = 1.0, eps_end: float = 0.1, annealing_num_steps: int = 1000, @@ -123,17 +123,22 @@ def __init__( "eps", torch.as_tensor(eps_init, dtype=torch.float32, device=device) ) - if spec is not None: - if not isinstance(spec, Composite) and len(self.out_keys) >= 1: - spec = Composite({action_key: spec}, shape=spec.shape[:-1]) - if device is not None: - spec = spec.to(device) - self._spec = spec + self.spec = spec @property def spec(self): return self._spec + @spec.setter + def spec(self, value: TensorSpec | None) -> None: + if value is not None: + if not isinstance(value, Composite) and len(self.out_keys) >= 1: + value = Composite({self.action_key: value}, shape=value.shape[:-1]) + if self.eps.device is not None: + value = value.to(self.eps.device) + + self._spec = value + def step(self, frames: int = 1) -> None: """A step of epsilon decay. @@ -203,7 +208,10 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: r = r.to(device) action = torch.where(cond, r, action) else: - raise RuntimeError("spec must be provided to the exploration wrapper.") + raise RuntimeError( + "spec has not been set. Pass spec at construction time or set it via " + "the `spec` property before calling forward()." + ) action_tensordict.set(action_key, action) return tensordict @@ -518,7 +526,7 @@ class OrnsteinUhlenbeckProcessModule(TensorDictModuleBase): def __init__( self, - spec: TensorSpec, + spec: TensorSpec | None, eps_init: float = 1.0, eps_end: float = 0.1, annealing_num_steps: int = 1000, @@ -564,20 +572,7 @@ def __init__( self.in_keys = [self.ou.key] self.out_keys = [self.ou.key] + self.ou.out_keys self.is_init_key = is_init_key - noise_key = self.ou.noise_key - steps_key = self.ou.steps_key - - if spec is not None: - if not isinstance(spec, Composite) and len(self.out_keys) >= 1: - spec = Composite({action_key: spec}, shape=spec.shape[:-1]) - self._spec = spec - else: - raise RuntimeError("spec cannot be None.") - ou_specs = { - noise_key: None, - steps_key: None, - } - self._spec.update(ou_specs) + self.spec = spec if len(set(self.out_keys)) != len(self.out_keys): raise RuntimeError(f"Got multiple identical output keys: {self.out_keys}") self.safe = safe @@ -588,6 +583,21 @@ def __init__( def spec(self): return self._spec + @spec.setter + def spec(self, value: TensorSpec | None) -> None: + if value is None: + self._spec = None + return + if not isinstance(value, Composite) and len(self.out_keys) >= 1: + value = Composite({self.ou.key: value}, shape=value.shape[:-1]) + ou_specs = { + self.ou.noise_key: None, + self.ou.steps_key: None, + } + value = value.clone() + value.update(ou_specs) + self._spec = value + def step(self, frames: int = 1) -> None: """Updates the eps noise factor. @@ -828,8 +838,8 @@ def __call__(self, td: TensorDictBase) -> TensorDictBase: def set_exploration_modules_spec_from_env(policy: nn.Module, env: EnvBase) -> None: """Sets exploration module specs from an environment action spec. - This is intended for cases where exploration modules (e.g. AdditiveGaussianModule) - are instantiated with ``spec=None`` and must be configured once the environment + This is intended for cases where exploration modules (e.g. AdditiveGaussianModule, + EGreedyModule, OrnsteinUhlenbeckProcessModule) are instantiated with ``spec=None`` and must be configured once the environment is known (e.g. inside a collector). """ action_spec = ( @@ -838,6 +848,13 @@ def set_exploration_modules_spec_from_env(policy: nn.Module, env: EnvBase) -> No else env.action_spec ) + exploration_modules = ( + AdditiveGaussianModule, + EGreedyModule, + OrnsteinUhlenbeckProcessModule, + ) + for submodule in policy.modules(): - if isinstance(submodule, AdditiveGaussianModule) and submodule._spec is None: - submodule.spec = action_spec + if isinstance(submodule, exploration_modules): + if submodule.spec is None: + submodule.spec = action_spec