diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index dfedd118219..5544e48c0bb 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -376,6 +376,9 @@ class EnvBase(nn.Module, metaclass=_EnvPostInit): are to be expected. Can be ``None``. is_spec_locked (bool): returns ``True`` if the specs are locked. See the :attr:`spec_locked` argument above. + _skip_maybe_reset (bool): if ``True``, :meth:`step_and_maybe_reset` will skip calling + :meth:`maybe_reset` after each step. This is useful for auto-resetting environments + that already handle resets inside their :meth:`_step` method. Defaults to ``False``. Methods: step (TensorDictBase -> TensorDictBase): step in the environment @@ -477,6 +480,7 @@ class EnvBase(nn.Module, metaclass=_EnvPostInit): _batch_size: torch.Size | None _device: torch.device | None _is_spec_locked: bool = False + _skip_maybe_reset: bool = False def __init__( self, @@ -3829,7 +3833,8 @@ def step_and_maybe_reset( tensordict_ = self._step_mdp(tensordict) # if self._post_step_mdp_hooks is not None: # tensordict_ = self._post_step_mdp_hooks(tensordict_) - tensordict_ = self.maybe_reset(tensordict_) + if not self._skip_maybe_reset: + tensordict_ = self.maybe_reset(tensordict_) return tensordict, tensordict_ # _post_step_mdp_hooks: Callable[[TensorDictBase], TensorDictBase] | None = None