Skip to content

Commit 80e48f7

Browse files
hujc7, AntoineRichard, kellyguo11
authored
[Exp] Cherry-pick warp MDP migration and capture safety from dev/newton (#4945)
## Summary * Cherry-picks [Newton] Migrate more envs and mdps to warp (#4690) onto develop * Cherry-picks [Newton] Add capture safety guards and fix WrenchComposer stale COM pose (#4779) onto develop ### Changes included - Warp-first MDP terms (observations, rewards, events, terminations, actions) for manager-based envs - Tested warp env configs: Ant, Humanoid, Cartpole, locomotion velocity (A1, AnymalB/C/D, Cassie, G1, Go1/2, H1), Franka/UR10 reach - ManagerCallSwitch max_mode cap and scene capture config - MDP kernels made graph-capturable with consolidated test infrastructure - capture_unsafe safety guards on lazy-evaluated derived properties in articulation/rigid_object data - WrenchComposer fix: use fresh COM pose buffers instead of stale cached link poses ### Dropped - G1-29-DOF warp env (Isaac-Velocity-Flat-G1-Warp-v1): removed because the stable g1_29_dofs task config does not exist on develop (only on dev/newton). Warp env PRs should only add warp frontends for envs that already exist in the stable package. ## Dependencies Must be merged **after** these PRs (in order): 1. #4905 (merged) 2. #4829 ## Validated base Validated against develop at 7588fa9. ## Test plan - [x] Run warp env training sweep across all manager-based env configs (14/14 pass, mode=2, 4096 envs, 300 iters) - [ ] Run test_mdp_warp_parity.py and test_mdp_warp_parity_new_terms.py - [ ] Run test_action_warp_parity.py - [ ] Verify WrenchComposer COM pose is fresh (not stale) during graph replay --------- Co-authored-by: Antoine Richard <antoiner@nvidia.com> Co-authored-by: Kelly Guo <kellyg@nvidia.com>
1 parent b69279d commit 80e48f7

85 files changed

Lines changed: 8911 additions & 168 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.
Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md).
2+
# All rights reserved.
3+
#
4+
# SPDX-License-Identifier: BSD-3-Clause
5+
6+
from __future__ import annotations
7+
8+
import functools
9+
from collections.abc import Sequence
10+
11+
import torch
12+
import warp as wp
13+
14+
##
15+
# Mask resolution - ids/mask to warp boolean mask.
16+
##
17+
18+
19+
@wp.kernel
def _populate_mask_from_ids(
    mask: wp.array(dtype=wp.bool),
    ids: wp.array(dtype=wp.int32),
):
    # One thread per id: flip the addressed slot of the boolean mask to True.
    tid = wp.tid()
    mask[ids[tid]] = True
26+
27+
28+
def resolve_1d_mask(
    *,
    ids: Sequence[int] | slice | torch.Tensor | wp.array | None = None,
    mask: wp.array | torch.Tensor | None = None,
    all_mask: wp.array,
    scratch_mask: wp.array,
    device: str,
) -> wp.array:
    """Resolve ids/mask into a warp boolean mask.

    Callers provide pre-allocated ``all_mask`` (all-True) and ``scratch_mask`` (reusable
    work buffer) so the returned mask buffers are never allocated here. Note that the
    ``Sequence[int]`` and ``slice`` id paths still allocate a temporary index array,
    and torch inputs may be copied for device/dtype/contiguity conversion — only the
    ``wp.array`` fast paths are allocation-free (and therefore graph-capturable).

    Args:
        ids: Index ids. Accepts ``Sequence[int]``, ``slice``, ``torch.Tensor``,
            ``wp.array(dtype=wp.int32)``, or ``None`` (all elements).
        mask: Direct boolean mask. ``wp.array`` is returned as-is;
            ``torch.Tensor`` is converted.
        all_mask: Pre-allocated all-True mask returned when both *ids* and *mask*
            are ``None``.
        scratch_mask: Pre-allocated scratch buffer populated in-place when *ids*
            are provided. Not re-entrant (shared buffer).
        device: Warp device string (e.g. ``"cuda:0"``).

    Returns:
        A ``wp.array(dtype=wp.bool)`` mask.

    Raises:
        RuntimeError: If called during CUDA graph capture with inputs that require
            host-side work (any *ids*, or a non-``wp.array`` *mask*).
        TypeError: If *mask* has an unsupported type, or *ids* is a ``wp.array``
            whose dtype is not ``wp.int32``.
    """
    # Normalize slice(None) to None so the capture guard treats it identically to ids=None.
    if isinstance(ids, slice) and ids == slice(None):
        ids = None

    # Check capture state on the device this call actually targets, not the current
    # default device — the two can differ in multi-GPU setups.
    if wp.get_device(device).is_capturing:
        if ids is not None or (mask is not None and not isinstance(mask, wp.array)):
            raise RuntimeError(
                "resolve_1d_mask is only capturable when mask is a wp.array or both ids and mask are None."
            )

    # --- Direct mask input ---
    if mask is not None:
        if isinstance(mask, wp.array):
            return mask
        if isinstance(mask, torch.Tensor):
            # Convert dtype/device only when needed so conforming tensors are wrapped
            # without copies.
            if mask.dtype != torch.bool:
                mask = mask.to(dtype=torch.bool)
            if str(mask.device) != device:
                mask = mask.to(device)
            return wp.from_torch(mask, dtype=wp.bool)
        raise TypeError(f"Unsupported mask type: {type(mask)}")

    # --- Fast path: all elements ---
    if ids is None:
        return all_mask

    # --- Normalize slice to list ---
    if isinstance(ids, slice):
        # scratch_mask's length defines the valid index range for the slice.
        start, stop, step = ids.indices(scratch_mask.shape[0])
        ids = list(range(start, stop, step))

    # --- Normalize to concrete type ---
    if not isinstance(ids, (torch.Tensor, wp.array)):
        ids = list(ids)

    # --- Populate scratch mask ---
    scratch_mask.fill_(False)

    if isinstance(ids, torch.Tensor):
        if ids.numel() == 0:
            return scratch_mask
        # Lazily move/convert so already-conforming tensors incur no copies.
        if str(ids.device) != device:
            ids = ids.to(device)
        if ids.dtype != torch.int32:
            ids = ids.to(dtype=torch.int32)
        if not ids.is_contiguous():
            ids = ids.contiguous()
        ids_wp = wp.from_torch(ids, dtype=wp.int32)
    elif isinstance(ids, wp.array):
        if ids.shape[0] == 0:
            return scratch_mask
        if ids.dtype != wp.int32:
            raise TypeError(f"Unsupported wp.array dtype for ids: {ids.dtype}. Expected wp.int32 index array.")
        ids_wp = ids
    else:
        if len(ids) == 0:
            return scratch_mask
        # Host list path: allocates a temporary device index array (not capturable).
        ids_wp = wp.array(ids, dtype=wp.int32, device=device)

    wp.launch(_populate_mask_from_ids, dim=ids_wp.shape[0], inputs=[scratch_mask, ids_wp], device=device)
    return scratch_mask
116+
117+
118+
##
119+
# Capture safety — property guard.
120+
##
121+
122+
123+
def capture_unsafe(reason: str | None = None):
124+
"""Mark a callable as not CUDA-graph-capture-safe.
125+
126+
Raises ``RuntimeError`` if the decorated callable is invoked while
127+
``wp.get_device().is_capturing`` is ``True``.
128+
129+
Args:
130+
reason: Optional explanation appended to the error message.
131+
132+
Usage::
133+
134+
@property
135+
@capture_unsafe("Relies on a Python timestamp guard.")
136+
def projected_gravity_b(self) -> wp.array: ...
137+
"""
138+
139+
def decorator(func):
140+
@functools.wraps(func)
141+
def wrapper(*args, **kwargs):
142+
if wp.get_device().is_capturing:
143+
msg = f"'{func.__qualname__}' cannot be called during CUDA graph capture."
144+
if reason:
145+
msg = f"{msg} {reason}"
146+
raise RuntimeError(msg)
147+
return func(*args, **kwargs)
148+
149+
return wrapper
150+
151+
return decorator

source/isaaclab_experimental/isaaclab_experimental/envs/manager_based_env_warp.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,10 @@ def __init__(self, cfg: ManagerBasedEnvCfg):
7676
self.cfg = cfg
7777
# initialize internal variables
7878
self._is_closed = False
79-
self._manager_call_switch = ManagerCallSwitch()
79+
# temporary debug runtime config for manager source/call switching.
80+
cfg_source: dict | str | None = getattr(self.cfg, "manager_call_config", None)
81+
max_modes: dict[str, int] | None = getattr(self.cfg, "manager_call_max_mode", None)
82+
self._manager_call_switch = ManagerCallSwitch(cfg_source, max_modes=max_modes)
8083
self._apply_manager_term_cfg_profile()
8184

8285
# set the seed for the environment
@@ -265,6 +268,17 @@ def device(self):
265268
"""The device on which the environment is running."""
266269
return self.sim.device
267270

271+
@property
def env_origins_wp(self) -> wp.array:
    """Scene env origins exposed as a warp ``vec3f`` array (lazily cached).

    The first access reads ``self.scene.env_origins``; if the value is not already
    a ``wp.array``, the torch tensor is wrapped via ``wp.from_torch``. Subsequent
    accesses return the cached array.
    """
    cached = getattr(self, "_env_origins_wp", None)
    if cached is None:
        origins = self.scene.env_origins
        if not isinstance(origins, wp.array):
            origins = wp.from_torch(origins, dtype=wp.vec3f)
        self._env_origins_wp = cached = origins
    return cached
268282
def resolve_env_mask(
269283
self,
270284
*,

0 commit comments

Comments
 (0)