Closed
8 changes: 7 additions & 1 deletion torchrl/envs/common.py
@@ -379,6 +379,10 @@ class EnvBase(nn.Module, metaclass=_EnvPostInit):
         _skip_maybe_reset (bool): if ``True``, :meth:`step_and_maybe_reset` will skip calling
             :meth:`maybe_reset` after each step. This is useful for auto-resetting environments
             that already handle resets inside their :meth:`_step` method. Defaults to ``False``.
+        _trust_step_output (bool): if ``True``, :meth:`step` will skip the :meth:`_step_proc_data`
+            validation (reward shape checks, done-key completion, type checks) after :meth:`_step`.
+            Set this when the environment guarantees that its :meth:`_step` output always has correct
+            shapes, all done keys present, and proper dtypes. Defaults to ``False``.
Comment on lines +382 to +385
Collaborator Author

This is a bit of a footgun. I'm kind of OK with it, but looking at other PRs down the stack, we MUST make sure that all the side effects are thoroughly documented, such as the lack of partial-step support (which requires control flow). This should also be marked with a prominent "Experimental" flag for anyone who wants to play with it.


     Methods:
         step (TensorDictBase -> TensorDictBase): step in the environment
@@ -481,6 +485,7 @@ class EnvBase(nn.Module, metaclass=_EnvPostInit):
     _device: torch.device | None
     _is_spec_locked: bool = False
     _skip_maybe_reset: bool = False
+    _trust_step_output: bool = False

     def __init__(
         self,
@@ -2245,7 +2250,8 @@ def step(self, tensordict: TensorDictBase) -> TensorDictBase:

         if next_tensordict is None:
             next_tensordict = self._step(tensordict)
-            next_tensordict = self._step_proc_data(next_tensordict)
+            if not self._trust_step_output:
+                next_tensordict = self._step_proc_data(next_tensordict)
         if next_preset is not None:
             # tensordict could already have a "next" key
             # this could be done more efficiently by not excluding but just passing
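
The gating in this hunk can be illustrated outside TorchRL with a minimal sketch. Everything below is a hypothetical stand-in: plain dicts replace `TensorDict`, `SketchEnvBase` and its subclasses are invented names, and the validation is reduced to done-key completion plus a reward cast. It only shows what a caller gives up when the flag is set.

```python
# Hypothetical stand-in for the pattern in the diff; this is not TorchRL code.
# Plain dicts replace TensorDict, and validation is reduced to two steps.


class SketchEnvBase:
    # Class-level default, mirroring `_trust_step_output: bool = False`.
    _trust_step_output = False

    def _step(self, action):
        raise NotImplementedError

    def _step_proc_data(self, out):
        # Simplified validation: fill in missing done keys, coerce the reward.
        out = dict(out)
        out.setdefault("done", False)
        out.setdefault("terminated", out["done"])
        out["reward"] = float(out["reward"])
        return out

    def step(self, action):
        out = self._step(action)
        # Same shape as the change above: validation runs unless trusted.
        if not self._trust_step_output:
            out = self._step_proc_data(out)
        return out


class SloppyEnv(SketchEnvBase):
    def _step(self, action):
        # Omits "terminated" and returns an int reward.
        return {"reward": 1, "done": False}


class TrustedSloppyEnv(SloppyEnv):
    # Opting in without actually honouring the guarantee.
    _trust_step_output = True


checked = SloppyEnv().step(0)
unchecked = TrustedSloppyEnv().step(0)
print(checked)    # {'reward': 1.0, 'done': False, 'terminated': False}
print(unchecked)  # {'reward': 1, 'done': False}
```

In this sketch the trusted environment silently returns an incomplete output, which is the kind of side effect the review comment asks to have documented.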
6 changes: 4 additions & 2 deletions torchrl/envs/transforms/transforms.py
@@ -1263,8 +1263,10 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
             next_tensordict.update(
                 next_preset.exclude(*next_tensordict.keys(True, True))
             )
-        self.base_env._complete_done(self.base_env.full_done_spec, next_tensordict)
-        # we want the input entries to remain unchanged
+        if not self.base_env._trust_step_output:
+            self.base_env._complete_done(
+                self.base_env.full_done_spec, next_tensordict
+            )
         next_tensordict = self.transform._step(tensordict_in, next_tensordict)

         if partial_steps is not None:
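
In this hunk the wrapper reads the flag from the wrapped environment rather than from itself. A rough sketch of that delegation follows, again with hypothetical names (`InnerEnv`, `Wrapper`, `complete_done`, `flag_truncated`) and plain dicts instead of `TensorDict`. The transform that assumes a `"truncated"` key is an invented example of downstream code that may rely on done-key completion; it is not taken from the PR.

```python
# Hypothetical sketch; names are illustrative and do not come from TorchRL.


def complete_done(done_keys, out):
    # Add any missing done-type key with a default of False, in place.
    for key in done_keys:
        out.setdefault(key, False)


class InnerEnv:
    _trust_step_output = False
    done_keys = ("done", "terminated", "truncated")

    def _step(self, action):
        # Only "done" is produced; the other done keys are left out.
        return {"obs": action + 1, "done": False}


class Wrapper:
    def __init__(self, base_env, transform):
        self.base_env = base_env
        self.transform = transform

    def _step(self, action):
        out = self.base_env._step(action)
        # The wrapper consults the wrapped env's flag, as in the diff.
        if not self.base_env._trust_step_output:
            complete_done(self.base_env.done_keys, out)
        return self.transform(out)


def flag_truncated(out):
    # A transform that assumes "truncated" is always present.
    return {**out, "ended": out["done"] or out["truncated"]}


env = InnerEnv()
ok = Wrapper(env, flag_truncated)._step(0)
print(ok["ended"])  # False

env._trust_step_output = True
failed = False
try:
    Wrapper(env, flag_truncated)._step(0)
except KeyError as err:
    failed = True
    print("missing key:", err)  # missing key: 'truncated'
```

Under these assumptions, setting the flag on the inner environment changes what every wrapper around it sees, so the guarantee has to hold for all downstream transforms and not only for the base `step`.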