diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 5544e48c0bb..5a4698346aa 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -379,6 +379,10 @@ class EnvBase(nn.Module, metaclass=_EnvPostInit): _skip_maybe_reset (bool): if ``True``, :meth:`step_and_maybe_reset` will skip calling :meth:`maybe_reset` after each step. This is useful for auto-resetting environments that already handle resets inside their :meth:`_step` method. Defaults to ``False``. + _trust_step_output (bool): if ``True``, :meth:`step` will skip the :meth:`_step_proc_data` + validation (reward shape checks, done-key completion, type checks) after :meth:`_step`, and + the transform wrapper's ``_step`` skips :meth:`_complete_done` for this env. Set this only when + :meth:`_step` always returns correct shapes, all done keys, and proper dtypes. Defaults to ``False``. Methods: step (TensorDictBase -> TensorDictBase): step in the environment @@ -481,6 +485,7 @@ class EnvBase(nn.Module, metaclass=_EnvPostInit): _device: torch.device | None _is_spec_locked: bool = False _skip_maybe_reset: bool = False + _trust_step_output: bool = False def __init__( self, @@ -2245,7 +2250,8 @@ def step(self, tensordict: TensorDictBase) -> TensorDictBase: if next_tensordict is None: next_tensordict = self._step(tensordict) - next_tensordict = self._step_proc_data(next_tensordict) + if not self._trust_step_output: + next_tensordict = self._step_proc_data(next_tensordict) if next_preset is not None: # tensordict could already have a "next" key # this could be done more efficiently by not excluding but just passing diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 5a4fc241169..1502fd158c4 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -1263,8 +1263,10 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: next_tensordict.update( next_preset.exclude(*next_tensordict.keys(True, True)) ) - 
self.base_env._complete_done(self.base_env.full_done_spec, next_tensordict) - # we want the input entries to remain unchanged + if not self.base_env._trust_step_output: + self.base_env._complete_done( + self.base_env.full_done_spec, next_tensordict + ) next_tensordict = self.transform._step(tensordict_in, next_tensordict) if partial_steps is not None: