OMPE-88188: Add frame stacking support for explicit temporal info #5574
Status: Open

jmart-nv wants to merge 3 commits into isaac-sim:develop from jmart-nv:jmart/frame-stacking.
Changelog entry:

Added
^^^^^

* Added :class:`~isaaclab.envs.utils.FrameStackBuffer`, a ring buffer that stacks the last
  ``N`` rendered frames along the channel dimension for tasks that need explicit temporal
  observations.
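For intuition, the stacked output described above has the same layout you would get by concatenating the last ``N`` frames along the last (channel) dimension, oldest to newest. A minimal illustrative snippet, not part of the PR (the buffer below produces this layout without allocating a new tensor via ``torch.cat`` each step):

import torch

# Illustration only: what "stacking along the channel dimension" means.
num_envs, H, W, C, N = 4, 8, 8, 3, 3
frames = [torch.full((num_envs, H, W, C), v, dtype=torch.uint8) for v in (1, 2, 3)]
stacked = torch.cat(frames, dim=-1)  # oldest to newest along the last dim
assert stacked.shape == (num_envs, H, W, C * N)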
New module (``isaaclab.envs.utils.frame_stack``):

# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md).
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

"""Frame-stacking helper for camera-based RL tasks.

Provides :class:`FrameStackBuffer`, a ring buffer over the last ``N`` rendered frames
that tasks can use to supply explicit temporal observations to a policy.
"""

from __future__ import annotations

from collections.abc import Sequence

import torch


class FrameStackBuffer:
Contributor review comment (attached here): I think we can avoid a … This can become: …
| """Ring buffer that stacks the last ``frame_stack`` rendered frames along the channel dim. | ||
|
|
||
| Example:: | ||
|
|
||
| self._stack = FrameStackBuffer( | ||
| single_frame_shape=(self.num_envs, H, W, C), | ||
| frame_stack=self.cfg.frame_stack, | ||
| device=self.device, | ||
| ) | ||
| # in _get_observations: | ||
| stacked = self._stack.update(rgb) | ||
| # in _reset_idx: | ||
| self._stack.reset(env_ids) | ||
|
|
||
| Args: | ||
| single_frame_shape: Shape of one rendered frame, ``(num_envs, H, W, C)``. | ||
| frame_stack: Number of frames to keep. Must be ``>= 1``; ``1`` is a passthrough. | ||
| device: Torch device for the internal buffers. | ||
| dtype: Torch dtype for the internal buffers. Defaults to :obj:`torch.uint8`. | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| single_frame_shape: tuple[int, ...], | ||
| frame_stack: int, | ||
| device: str | torch.device, | ||
| dtype: torch.dtype = torch.uint8, | ||
| ): | ||
| if frame_stack < 1: | ||
| raise ValueError(f"frame_stack must be >= 1, got {frame_stack}.") | ||
| if len(single_frame_shape) < 2: | ||
| raise ValueError( | ||
| f"single_frame_shape must have at least 2 dims (envs + channels), got {single_frame_shape}." | ||
| ) | ||
| self.frame_stack: int = frame_stack | ||
| self._single_shape: tuple[int, ...] = tuple(int(d) for d in single_frame_shape) | ||
| self._num_envs: int = self._single_shape[0] | ||
| self._channels: int = self._single_shape[-1] | ||
| self._device = torch.device(device) if isinstance(device, str) else device | ||
| self._dtype = dtype | ||
|
|
||
| self._history: torch.Tensor = torch.zeros((frame_stack, *self._single_shape), device=self._device, dtype=dtype) | ||
| self._stacked: torch.Tensor = torch.zeros( | ||
| (*self._single_shape[:-1], self._channels * frame_stack), device=self._device, dtype=dtype | ||
| ) | ||
| self._frame_idx: int = 0 | ||
| self._needs_init: torch.Tensor = torch.ones(self._num_envs, device=self._device, dtype=torch.bool) | ||
| # CPU-side mirror of _needs_init.any() — avoids a GPU→CPU sync on the steady-state path. | ||
| self._needs_init_cpu: bool = True | ||
|
|
||
| @property | ||
| def output_shape(self) -> tuple[int, ...]: | ||
| """Shape of the tensor returned by :meth:`update`, ``(num_envs, H, W, C * frame_stack)``.""" | ||
| return (*self._single_shape[:-1], self._channels * self.frame_stack) | ||
|
|
||
| @property | ||
| def output_channels(self) -> int: | ||
| """Channel count of the stacked output (``= single_channels * frame_stack``).""" | ||
| return self._channels * self.frame_stack | ||
|
|
||
| def update(self, single_frame: torch.Tensor) -> torch.Tensor: | ||
| """Push a new frame and return the stacked output. | ||
|
|
||
| On the first :meth:`update` after construction or :meth:`reset` for an env, all | ||
| history slots for that env are filled with ``single_frame`` so the policy never | ||
| sees zero-padded warmup data. | ||
|
|
||
| Args: | ||
| single_frame: New rendered frame, shape ``(num_envs, H, W, C)``. | ||
|
|
||
| Returns: | ||
| Stacked tensor ``(num_envs, H, W, C * frame_stack)`` in oldest-to-newest | ||
| channel order. This is the buffer's own storage — do not mutate it. | ||
| """ | ||
|
        if single_frame.shape != self._single_shape:
            raise ValueError(
                f"single_frame shape {tuple(single_frame.shape)} does not match expected "
                f"{self._single_shape} (set at construction)."
            )

        if self._needs_init_cpu:
            init_ids = self._needs_init.nonzero(as_tuple=False).squeeze(-1)
            if init_ids.numel() > 0:
                for i in range(self.frame_stack):
                    self._history[i, init_ids] = single_frame[init_ids]
            self._needs_init.zero_()
            self._needs_init_cpu = False

        self._history[self._frame_idx].copy_(single_frame)

        # narrow + copy_ rebuild avoids per-frame torch.cat allocations.
        for i in range(self.frame_stack):
            src_slot = (self._frame_idx + 1 + i) % self.frame_stack
            self._stacked.narrow(-1, i * self._channels, self._channels).copy_(self._history[src_slot])

        self._frame_idx = (self._frame_idx + 1) % self.frame_stack
        return self._stacked

    def reset(self, env_ids: Sequence[int] | torch.Tensor | None = None) -> None:
| """Mark envs for history re-initialization on the next :meth:`update`. | ||
|
|
||
| Args: | ||
| env_ids: Indices of envs to reset. ``None`` resets all envs. | ||
| """ | ||
| if env_ids is None: | ||
| self._needs_init.fill_(True) | ||
| else: | ||
| if not isinstance(env_ids, torch.Tensor): | ||
| env_ids = torch.as_tensor(env_ids, device=self._device, dtype=torch.long) | ||
| self._needs_init[env_ids] = True | ||
| self._needs_init_cpu = True | ||
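A minimal, self-contained usage sketch of the lifecycle above (first update, steady state, per-env reset). It is not part of the PR; shapes and values are illustrative only, and the import path is the one used by the unit tests below.

import torch

from isaaclab.envs.utils import FrameStackBuffer

num_envs, H, W, C = 2, 4, 4, 3
buf = FrameStackBuffer((num_envs, H, W, C), frame_stack=3, device="cpu")

# First update: every history slot is filled with the first frame, so no zero padding.
obs = buf.update(torch.full((num_envs, H, W, C), 1, dtype=torch.uint8))
assert obs.shape == (num_envs, H, W, C * 3)

# Steady state: channel slices run oldest to newest.
buf.update(torch.full((num_envs, H, W, C), 2, dtype=torch.uint8))
obs = buf.update(torch.full((num_envs, H, W, C), 3, dtype=torch.uint8))
assert int(obs[0, 0, 0, 0]) == 1 and int(obs[0, 0, 0, -1]) == 3

# Per-env reset: env 0 is re-initialized on the next update, env 1 keeps its history.
buf.reset([0])
obs = buf.update(torch.full((num_envs, H, W, C), 9, dtype=torch.uint8))
assert int(obs[0, 0, 0, 0]) == 9
assert int(obs[1, 0, 0, 0]) == 2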
New test module:

# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md).
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

"""Unit tests for :mod:`isaaclab.envs.utils.frame_stack`. Pure tensor logic; no Kit launch."""

from __future__ import annotations

import pytest
import torch

from isaaclab.envs.utils import FrameStackBuffer

pytestmark = pytest.mark.isaacsim_ci

# Shorthand shape (num_envs, H, W, C) used across tests.
NUM_ENVS = 4
HEIGHT = 8
WIDTH = 8
CHANNELS = 3
SINGLE_SHAPE = (NUM_ENVS, HEIGHT, WIDTH, CHANNELS)


def _make_frame(value: int, dtype: torch.dtype = torch.uint8) -> torch.Tensor:
    """Build a constant-valued (N, H, W, C) tensor on CPU."""
    return torch.full(SINGLE_SHAPE, value, dtype=dtype)


class TestFrameStackBuffer:
    """Pure-tensor tests of the ring buffer."""

    def test_output_shape_and_channels(self):
        buf = FrameStackBuffer(SINGLE_SHAPE, frame_stack=3, device="cpu")
        assert buf.output_shape == (NUM_ENVS, HEIGHT, WIDTH, CHANNELS * 3)
        assert buf.output_channels == CHANNELS * 3
        # The narrow+copy_ rebuild writes into a single pre-allocated buffer; output must stay contiguous.
        stacked = buf.update(_make_frame(1))
        assert stacked.is_contiguous()

    def test_init_fills_all_slots_on_first_update(self):
        """First update post-construction fills every history slot with the new frame."""
        buf = FrameStackBuffer(SINGLE_SHAPE, frame_stack=2, device="cpu")
        f0 = _make_frame(7)
        stacked = buf.update(f0)
        # Both slots equal F0.
        assert torch.equal(stacked[..., :CHANNELS], f0)
        assert torch.equal(stacked[..., CHANNELS:], f0)

    def test_ring_buffer_shifts_correctly(self):
        """After the second update, oldest slot = first frame; newest slot = second frame."""
        buf = FrameStackBuffer(SINGLE_SHAPE, frame_stack=2, device="cpu")
        f0 = _make_frame(10)
        f1 = _make_frame(20)
        buf.update(f0)
        stacked = buf.update(f1)
        assert torch.equal(stacked[..., :CHANNELS], f0), "Oldest slot must be the previous frame"
        assert torch.equal(stacked[..., CHANNELS:], f1), "Newest slot must be the latest frame"

    def test_newest_slot_equals_latest_single(self):
        """Ring-buffer correctness invariant: newest slot post-update == the latest single input."""
        buf = FrameStackBuffer(SINGLE_SHAPE, frame_stack=2, device="cpu")
        buf.update(_make_frame(1))
        f_latest = _make_frame(99)
        stacked = buf.update(f_latest)
        assert torch.equal(stacked[..., CHANNELS:], f_latest)

    def test_three_frame_stack_oldest_to_newest_order(self):
        """frame_stack=3 produces oldest→newest across 3 channel slices."""
        buf = FrameStackBuffer(SINGLE_SHAPE, frame_stack=3, device="cpu")
        buf.update(_make_frame(10))  # init: all 3 slots = 10
        buf.update(_make_frame(20))  # slots: [10, 10, 20]
        stacked = buf.update(_make_frame(30))  # slots: [10, 20, 30]
        assert torch.equal(stacked[..., :CHANNELS], _make_frame(10))
        assert torch.equal(stacked[..., CHANNELS : 2 * CHANNELS], _make_frame(20))
        assert torch.equal(stacked[..., 2 * CHANNELS :], _make_frame(30))

    def test_reset_all_envs(self):
        """reset() with no args re-inits every env on the next update."""
        buf = FrameStackBuffer(SINGLE_SHAPE, frame_stack=2, device="cpu")
        buf.update(_make_frame(1))
        buf.update(_make_frame(2))  # ring filled
        buf.reset()  # mark all envs for init
        stacked = buf.update(_make_frame(50))
        # All slots filled with 50.
        assert torch.equal(stacked[..., :CHANNELS], _make_frame(50))
        assert torch.equal(stacked[..., CHANNELS:], _make_frame(50))

    def test_reset_partial_envs_preserves_others(self):
        """Resetting env 0 should re-init only env 0; other envs keep their history."""
        buf = FrameStackBuffer(SINGLE_SHAPE, frame_stack=2, device="cpu")
        buf.update(_make_frame(1))
        buf.update(_make_frame(2))
        buf.reset(torch.tensor([0]))
        stacked = buf.update(_make_frame(9))
        # Env 0: both slots == 9 (init).
        assert torch.equal(stacked[0, ..., :CHANNELS], torch.full((HEIGHT, WIDTH, CHANNELS), 9, dtype=torch.uint8))
        assert torch.equal(stacked[0, ..., CHANNELS:], torch.full((HEIGHT, WIDTH, CHANNELS), 9, dtype=torch.uint8))
        # Env 1: oldest == 2 (ring shifted from previous), newest == 9.
        assert torch.equal(stacked[1, ..., :CHANNELS], torch.full((HEIGHT, WIDTH, CHANNELS), 2, dtype=torch.uint8))
        assert torch.equal(stacked[1, ..., CHANNELS:], torch.full((HEIGHT, WIDTH, CHANNELS), 9, dtype=torch.uint8))

    def test_frame_stack_one_passthrough(self):
        """frame_stack=1 effectively echoes the input (single-slot ring)."""
        buf = FrameStackBuffer(SINGLE_SHAPE, frame_stack=1, device="cpu")
        assert buf.output_shape == SINGLE_SHAPE
        f = _make_frame(42)
        stacked = buf.update(f)
        assert torch.equal(stacked, f)

    def test_invalid_frame_stack_raises(self):
        with pytest.raises(ValueError, match="frame_stack must be >= 1"):
            FrameStackBuffer(SINGLE_SHAPE, frame_stack=0, device="cpu")

    def test_invalid_shape_raises(self):
        with pytest.raises(ValueError, match="at least 2 dims"):
            FrameStackBuffer((10,), frame_stack=2, device="cpu")

    def test_wrong_input_shape_raises(self):
        """update() rejects a frame whose shape doesn't match the construction shape."""
        buf = FrameStackBuffer(SINGLE_SHAPE, frame_stack=2, device="cpu")
        with pytest.raises(ValueError, match="does not match expected"):
            buf.update(torch.zeros((NUM_ENVS, HEIGHT, WIDTH, CHANNELS + 1), dtype=torch.uint8))

    def test_dtype_preserved_uint8(self):
        buf = FrameStackBuffer(SINGLE_SHAPE, frame_stack=2, device="cpu", dtype=torch.uint8)
        stacked = buf.update(_make_frame(5))
        assert stacked.dtype == torch.uint8

    def test_dtype_preserved_float32(self):
        buf = FrameStackBuffer(SINGLE_SHAPE, frame_stack=2, device="cpu", dtype=torch.float32)
        stacked = buf.update(_make_frame(5, dtype=torch.float32))
        assert stacked.dtype == torch.float32

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available in this env")
    def test_buffer_on_cuda(self):
        """Buffer allocates and operates correctly on a CUDA device."""
        buf = FrameStackBuffer(SINGLE_SHAPE, frame_stack=2, device="cuda")
        f0 = torch.full(SINGLE_SHAPE, 7, dtype=torch.uint8, device="cuda")
        stacked = buf.update(f0)
        assert stacked.device.type == "cuda"
        assert stacked.shape == (NUM_ENVS, HEIGHT, WIDTH, CHANNELS * 2)
        # Both slots filled with f0 on the init path.
        assert torch.equal(stacked[..., :CHANNELS], f0)
        assert torch.equal(stacked[..., CHANNELS:], f0)
        # Steady-state shift works on CUDA too.
        f1 = torch.full(SINGLE_SHAPE, 13, dtype=torch.uint8, device="cuda")
        stacked = buf.update(f1)
        assert torch.equal(stacked[..., :CHANNELS], f0)
        assert torch.equal(stacked[..., CHANNELS:], f1)

    def test_long_run_ring_stability(self):
        """After many updates exceeding frame_stack cycles, the oldest-to-newest layout stays correct."""
        buf = FrameStackBuffer(SINGLE_SHAPE, frame_stack=3, device="cpu")
        # Push 11 frames with values 0..10. After the last update, the ring slots should
        # hold the 3 most-recent frames: [8, 9, 10] in oldest-to-newest order.
        for i in range(11):
            stacked = buf.update(_make_frame(i))
        assert torch.equal(stacked[..., :CHANNELS], _make_frame(8))
        assert torch.equal(stacked[..., CHANNELS : 2 * CHANNELS], _make_frame(9))
        assert torch.equal(stacked[..., 2 * CHANNELS :], _make_frame(10))

    def test_reset_accepts_python_sequence(self):
        """reset() accepts a plain ``list[int]`` (the type DirectRLEnv hands to ``_reset_idx``)."""
        buf = FrameStackBuffer(SINGLE_SHAPE, frame_stack=2, device="cpu")
        buf.update(_make_frame(1))
        buf.update(_make_frame(2))
        buf.reset([0, 2])
        stacked = buf.update(_make_frame(9))
        per_env_shape = (HEIGHT, WIDTH, CHANNELS)
        nines = torch.full(per_env_shape, 9, dtype=torch.uint8)
        twos = torch.full(per_env_shape, 2, dtype=torch.uint8)
        for env_id in (0, 2):
            assert torch.equal(stacked[env_id, ..., :CHANNELS], nines), f"env {env_id} oldest"
            assert torch.equal(stacked[env_id, ..., CHANNELS:], nines), f"env {env_id} newest"
        for env_id in (1, 3):
            assert torch.equal(stacked[env_id, ..., :CHANNELS], twos), f"env {env_id} oldest"
            assert torch.equal(stacked[env_id, ..., CHANNELS:], nines), f"env {env_id} newest"

    def test_reset_multi_env_subset_preserves_unrelated(self):
        """Resetting envs [0, 2] should re-init only those; envs [1, 3] keep their history."""
        buf = FrameStackBuffer(SINGLE_SHAPE, frame_stack=2, device="cpu")
        buf.update(_make_frame(1))
        buf.update(_make_frame(2))  # ring filled
        buf.reset(torch.tensor([0, 2]))
        stacked = buf.update(_make_frame(9))
        per_env_shape = (HEIGHT, WIDTH, CHANNELS)
        nines = torch.full(per_env_shape, 9, dtype=torch.uint8)
        twos = torch.full(per_env_shape, 2, dtype=torch.uint8)
        # Reset envs: both slots = 9 (init).
        for env_id in (0, 2):
            assert torch.equal(stacked[env_id, ..., :CHANNELS], nines), f"env {env_id} oldest"
            assert torch.equal(stacked[env_id, ..., CHANNELS:], nines), f"env {env_id} newest"
        # Untouched envs: oldest = 2 (shifted from previous newest), newest = 9.
        for env_id in (1, 3):
            assert torch.equal(stacked[env_id, ..., :CHANNELS], twos), f"env {env_id} oldest"
            assert torch.equal(stacked[env_id, ..., CHANNELS:], nines), f"env {env_id} newest"
source/isaaclab_tasks/changelog.d/jmart-frame-stacking.minor.rst (8 additions, 0 deletions):

Added
^^^^^

* Added :class:`~isaaclab_tasks.direct.cartpole.cartpole_camera_presets_env.CartpoleCameraPresetsEnv`,
  a subclass of :class:`~isaaclab_tasks.direct.cartpole.cartpole_camera_env.CartpoleCameraEnv` that
  wires :class:`~isaaclab.envs.utils.FrameStackBuffer` into the ``Isaac-Cartpole-Camera-Presets-Direct-v0``
  task. ``frame_stack`` defaults to ``2`` for the Newton + Warp combo and ``1`` otherwise;
  CLI overrides via ``env.frame_stack=N`` are respected.
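The presets environment itself is not part of the captured diff, so the following is only a hypothetical sketch of the wiring the changelog describes: the class name, the ``_render_rgb`` helper, and the observation dict key are invented for illustration, and only ``FrameStackBuffer``'s API comes from the module above.

import torch

from isaaclab.envs.utils import FrameStackBuffer


class FrameStackedCameraTaskSketch:
    """Shows where a camera task could call update()/reset() on the buffer."""

    def __init__(self, num_envs: int = 2, height: int = 8, width: int = 8, channels: int = 3, frame_stack: int = 2):
        self.device = "cpu"
        self._shape = (num_envs, height, width, channels)
        # frame_stack=1 keeps the buffer as a passthrough, so the wiring can stay unconditional.
        self._frame_stack_buf = FrameStackBuffer(self._shape, frame_stack=frame_stack, device=self.device)
        # Size the policy's image observation from the stacked channel count.
        self.obs_channels = self._frame_stack_buf.output_channels

    def _render_rgb(self) -> torch.Tensor:
        # Stand-in for the task's camera render, one (num_envs, H, W, C) frame per step.
        return torch.randint(0, 255, self._shape, dtype=torch.uint8)

    def _get_observations(self) -> dict[str, torch.Tensor]:
        # Push the newest frame; the returned tensor holds the last N frames, oldest to newest.
        return {"policy": self._frame_stack_buf.update(self._render_rgb())}

    def _reset_idx(self, env_ids) -> None:
        # Re-initialize history for just the reset envs; others keep their frames.
        self._frame_stack_buf.reset(env_ids)


# Usage sketch:
task = FrameStackedCameraTaskSketch(frame_stack=2)
obs = task._get_observations()["policy"]
assert obs.shape[-1] == task.obs_channels
task._reset_idx([0])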
Reviewer comment: I think this would be better placed in https://isaac-sim.github.io/IsaacLab/develop/source/overview/core-concepts/renderers.html since it's a general note around the renderer behavior difference.