Commit a6cd73e

Enable FabricFrameView on non-primary GPUs
- Allow FabricFrameView to run on cuda:N for any N; USDRT SelectPrims no longer needs cuda:0.
- Refactor the Fabric write path into a single _compose_fabric_transform helper shared by set_world_poses, set_scales, and the initial USD->Fabric sync, collapsing the sync to one kernel launch with one PrepareForReuse.
- Replace the topology-invariant assert with RuntimeError so it survives python -O.
- Add multi_gpu pytest marker plus cuda:1 unit-test coverage for both Fabric write paths, and run them in the existing test-multi-gpu CI job (one extra step, no new job).
1 parent efd9d1e commit a6cd73e

5 files changed: 154 additions & 72 deletions

.github/workflows/test-multi-gpu.yaml

Lines changed: 9 additions & 0 deletions

@@ -20,6 +20,8 @@ on:
       - "source/isaaclab/isaaclab/app/app_launcher.py"
       - "source/isaaclab_tasks/isaaclab_tasks/utils/sim_launcher.py"
       - "scripts/reinforcement_learning/**/train.py"
+      - "source/isaaclab_physx/isaaclab_physx/sim/views/fabric_frame_view.py"
+      - "source/isaaclab_physx/test/sim/test_views_xform_prim_fabric.py"
       - ".github/workflows/test-multi-gpu.yaml"
   workflow_dispatch:

@@ -104,6 +106,13 @@ jobs:
             exit 1
           fi

+      - name: Run FabricFrameView multi-GPU unit tests
+        # Cheap (~tens of seconds) and only depends on torch/warp + 2 GPUs,
+        # so run once per matrix entry rather than carving out a separate job.
+        run: |
+          ./isaaclab.sh -p -m pytest -m multi_gpu \
+            source/isaaclab_physx/test/sim/test_views_xform_prim_fabric.py -v
+
       - name: Run multi-GPU training (${{ matrix.physics }}, ${{ matrix.renderer }})
         env:
           NCCL_DEBUG: WARN

pyproject.toml

Lines changed: 1 addition & 0 deletions

@@ -139,6 +139,7 @@ ignore-words-list = "haa,slq,collapsable,buss,reacher,thirdparty"
 
 markers = [
     "isaacsim_ci: mark test to run in isaacsim ci",
+    "multi_gpu: tests that require 2+ GPUs; skipped automatically on single-GPU machines",
 ]
 
 # Add pypi.nvidia.com so that `uv pip install isaaclab[isaacsim]` works without --extra-index-url.
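
Registering the marker under [tool.pytest.ini_options] keeps pytest from warning about an unknown mark when the new tests use it. For reference, this is standard pytest usage (an illustrative snippet, not code from this commit): a test opts in with the decorator, and the group is selected or excluded on the command line with -m.

import pytest

@pytest.mark.multi_gpu  # marker registered in pyproject.toml above
def test_requires_two_gpus():
    ...

CI selects the group with "-m multi_gpu"; a local run can exclude it with "-m 'not multi_gpu'".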
Lines changed: 27 additions & 0 deletions

@@ -0,0 +1,27 @@
+Changed
+^^^^^^^
+
+* Combined the initial USD→Fabric sync in
+  :class:`~isaaclab_physx.sim.views.FabricFrameView` into a single Fabric
+  write so ``PrepareForReuse`` is invoked exactly once per logical update
+  (positions, orientations, and scales are composed in one kernel launch).
+  This avoids the possibility of a second non-idempotent
+  ``PrepareForReuse`` call masking a topology-change signal that should
+  have triggered a fabricarray rebuild.
+
+Fixed
+^^^^^
+
+* Fixed :class:`~isaaclab_physx.sim.views.FabricFrameView` falling back to
+  the slow USD path on every CUDA device other than ``cuda:0``. USDRT
+  ``SelectPrims`` now accepts any CUDA device index, so Fabric acceleration
+  runs on the simulation device the view was constructed with (e.g.
+  ``cuda:1``). This unblocks distributed training where each rank is
+  pinned to a non-primary GPU.
+
+* Fixed the topology-change invariant guard in
+  :class:`~isaaclab_physx.sim.views.FabricFrameView` not surviving
+  ``python -O``. The check now raises :class:`RuntimeError` instead of
+  using ``assert`` so the prim-count mismatch between view and Fabric is
+  reported at every optimisation level rather than silently producing
+  wrong poses or out-of-bounds kernel indices.
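
Background for the python -O entry above: the -O flag strips assert statements from the bytecode, so an assert-based invariant check simply disappears in optimised runs, while an explicit raise is ordinary control flow and always executes. A minimal stand-alone illustration (not code from this commit):

# guard_demo.py: compare `python guard_demo.py` with `python -O guard_demo.py`
view_count, fabric_count = 3, 4  # deliberately mismatched

# Fires under a normal run, but is compiled away entirely under -O,
# letting execution fall through to the explicit check below.
assert view_count == fabric_count, "prim count mismatch"

# Survives -O: plain control flow, so the mismatch is always reported.
if view_count != fabric_count:
    raise RuntimeError(f"Prim count changed ({view_count} vs {fabric_count}).")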

source/isaaclab_physx/isaaclab_physx/sim/views/fabric_frame_view.py

Lines changed: 29 additions & 71 deletions

@@ -23,12 +23,6 @@
 
 logger = logging.getLogger(__name__)
 
-# TODO: extend this to ``cuda:N`` once we wire up multi-GPU support for the view.
-# Recent Kit / USDRT releases do support multi-GPU ``SelectPrims``, but the
-# rest of the FabricFrameView wiring (selections, indexed arrays, etc.) still
-# assumes a single device — to be tackled in a follow-up.
-_fabric_supported_devices = ("cpu", "cuda", "cuda:0")
-
 
 def _to_float32_2d(a: wp.array | torch.Tensor) -> wp.array | torch.Tensor:
     """Ensure array is compatible with Fabric kernels (2-D float32).

@@ -92,15 +86,6 @@ def __init__(
         settings = SettingsManager.instance()
         self._use_fabric = bool(settings.get("/physics/fabricEnabled", False))
 
-        if self._use_fabric and self._device not in _fabric_supported_devices:
-            logger.warning(
-                f"Fabric mode is not supported on device '{self._device}'. "
-                "USDRT SelectPrims and Warp fabric arrays are currently "
-                f"only supported on {', '.join(_fabric_supported_devices)}. "
-                "Falling back to standard USD operations. This may impact performance."
-            )
-            self._use_fabric = False
-
         self._fabric_initialized = False
         self._fabric_usd_sync_done = False
         self._fabric_selection = None

@@ -149,43 +134,7 @@ def set_world_poses(self, positions=None, orientations=None, indices=None):
         if not self._use_fabric:
             self._usd_view.set_world_poses(positions, orientations, indices)
             return
-
-        if not self._fabric_initialized:
-            self._initialize_fabric()
-
-        self._prepare_for_reuse()
-
-        indices_wp = self._resolve_indices_wp(indices)
-        count = indices_wp.shape[0]
-
-        dummy = wp.zeros((0, 3), dtype=wp.float32, device=self._device)
-        positions_wp = _to_float32_2d(positions) if positions is not None else dummy
-        orientations_wp = (
-            _to_float32_2d(orientations)
-            if orientations is not None
-            else wp.zeros((0, 4), dtype=wp.float32, device=self._device)
-        )
-
-        wp.launch(
-            kernel=fabric_utils.compose_fabric_transformation_matrix_from_warp_arrays,
-            dim=count,
-            inputs=[
-                self._fabric_world_matrices,
-                positions_wp,
-                orientations_wp,
-                dummy,
-                False,
-                False,
-                False,
-                indices_wp,
-                self._view_to_fabric,
-            ],
-            device=self._fabric_device,
-        )
-        wp.synchronize()
-
-        self._fabric_hierarchy.update_world_xforms()
-        self._fabric_usd_sync_done = True
+        self._compose_fabric_transform(positions=positions, orientations=orientations, indices=indices)
 
     def get_world_poses(self, indices: wp.array | None = None) -> tuple[ProxyArray, ProxyArray]:
         if not self._use_fabric:

@@ -244,7 +193,15 @@ def set_scales(self, scales, indices=None):
         if not self._use_fabric:
             self._usd_view.set_scales(scales, indices)
             return
+        self._compose_fabric_transform(scales=scales, indices=indices)
+
+    def _compose_fabric_transform(self, positions=None, orientations=None, scales=None, indices=None):
+        """Write the given subset of (position, orientation, scale) into Fabric in one kernel launch.
 
+        Components left as ``None`` are skipped via empty input arrays — the kernel reads them
+        from the existing Fabric matrix. Always invokes :meth:`_prepare_for_reuse` exactly once
+        per write, even when multiple components are updated together.
+        """
         if not self._fabric_initialized:
             self._initialize_fabric()
 

@@ -253,17 +210,19 @@
         indices_wp = self._resolve_indices_wp(indices)
         count = indices_wp.shape[0]
 
-        dummy3 = wp.zeros((0, 3), dtype=wp.float32, device=self._device)
-        dummy4 = wp.zeros((0, 4), dtype=wp.float32, device=self._device)
-        scales_wp = _to_float32_2d(scales)
+        empty3 = wp.zeros((0, 3), dtype=wp.float32, device=self._device)
+        empty4 = wp.zeros((0, 4), dtype=wp.float32, device=self._device)
+        positions_wp = _to_float32_2d(positions) if positions is not None else empty3
+        orientations_wp = _to_float32_2d(orientations) if orientations is not None else empty4
+        scales_wp = _to_float32_2d(scales) if scales is not None else empty3
 
         wp.launch(
             kernel=fabric_utils.compose_fabric_transformation_matrix_from_warp_arrays,
             dim=count,
             inputs=[
                 self._fabric_world_matrices,
-                dummy3,
-                dummy4,
+                positions_wp,
+                orientations_wp,
                 scales_wp,
                 False,
                 False,

@@ -347,10 +306,11 @@ def _rebuild_fabric_arrays(self) -> None:
         pattern (via ``_usd_view.count``) and does not change when Fabric rearranges its
         internal memory layout. The assertion below guards this invariant.
         """
-        assert self.count == self._default_view_indices.shape[0], (
-            f"Prim count changed ({self.count} vs {self._default_view_indices.shape[0]}). "
-            "Fabric topology change added/removed tracked prims — full re-initialization required."
-        )
+        if self.count != self._default_view_indices.shape[0]:
+            raise RuntimeError(
+                f"Prim count changed ({self.count} vs {self._default_view_indices.shape[0]}). "
+                "Fabric topology change added/removed tracked prims — full re-initialization required."
+            )
         self._view_to_fabric = wp.zeros((self.count,), dtype=wp.uint32, device=self._fabric_device)
         self._fabric_to_view = wp.fabricarray(self._fabric_selection, self._view_index_attr)
 

@@ -404,9 +364,6 @@ def _initialize_fabric(self) -> None:
         )
         wp.synchronize()
 
-        # The constructor should have taken care of this, but double check here to avoid regressions
-        assert self._device in _fabric_supported_devices
-
         self._fabric_selection = fabric_stage.SelectPrims(
             require_attrs=[
                 (usdrt.Sdf.ValueTypeNames.UInt, self._view_index_attr, usdrt.Usd.Access.Read),

@@ -442,19 +399,20 @@
     def _sync_fabric_from_usd_once(self) -> None:
         """Sync Fabric world matrices from USD once, on the first read.
 
-        ``set_world_poses`` and ``set_scales`` each set ``_fabric_usd_sync_done``
-        themselves, so no explicit flag assignment is needed here.
+        Combines position/orientation/scale into a single Fabric write so
+        :meth:`_prepare_for_reuse` (and its underlying ``PrepareForReuse``) is invoked
+        exactly once across the full sync.
         """
         if not self._fabric_initialized:
             self._initialize_fabric()
 
         positions_usd_ta, orientations_usd_ta = self._usd_view.get_world_poses()
-        positions_usd = positions_usd_ta.warp
-        orientations_usd = orientations_usd_ta.warp
         scales_usd = self._usd_view.get_scales()
-
-        self.set_world_poses(positions_usd, orientations_usd)
-        self.set_scales(scales_usd)
+        self._compose_fabric_transform(
+            positions=positions_usd_ta.warp,
+            orientations=orientations_usd_ta.warp,
+            scales=scales_usd,
+        )
 
     def _resolve_indices_wp(self, indices: wp.array | None) -> wp.array:
         """Resolve view indices as a Warp uint32 array."""

source/isaaclab_physx/test/sim/test_views_xform_prim_fabric.py

Lines changed: 88 additions & 1 deletion

@@ -10,6 +10,7 @@
 Camera prim type for Fabric SelectPrims compatibility).
 """
 
+import os
 import sys
 from pathlib import Path
 

@@ -44,8 +45,17 @@ def test_setup_teardown():
 
 
 def _skip_if_unavailable(device: str):
-    if device.startswith("cuda") and not torch.cuda.is_available():
+    if not device.startswith("cuda"):
+        return
+    if not torch.cuda.is_available():
         pytest.skip("CUDA not available")
+    idx = int(device.split(":")[1]) if ":" in device else 0
+    n = torch.cuda.device_count()
+    if idx >= n:
+        msg = f"{device} not available (device_count={n})"
+        if os.environ.get("GITHUB_ACTIONS") == "true":
+            pytest.fail(f"{msg} — multi-GPU runner is misconfigured")
+        pytest.skip(f"{msg} — multi-GPU test skipped on single-GPU machine")
 
 
 # ------------------------------------------------------------------

@@ -233,3 +243,80 @@ def force_topology_changed():
     pos_torch = wp.to_torch(ret_pos)
     expected = torch.tensor([[4.0, 5.0, 6.0], [4.0, 5.0, 6.0]], device=device)
     assert torch.allclose(pos_torch, expected, atol=1e-7), f"Read after rebuild failed on {device}: {pos_torch}"
+
+
+# ------------------------------------------------------------------
+# Multi-GPU tests (cuda:1) — skipped automatically on single-GPU workstations
+# ------------------------------------------------------------------
+
+
+@pytest.mark.multi_gpu
+@pytest.mark.parametrize("device", ["cuda:1"])
+def test_fabric_cuda1_world_pose_roundtrip(device, view_factory):
+    """set_world_poses -> get_world_poses roundtrip works on cuda:1.
+
+    Verifies that FabricFrameView operates correctly on a non-primary CUDA
+    device without falling back to the USD path.
+    """
+    bundle = view_factory(2, device)
+    view = bundle.view
+
+    new_pos = wp.zeros((2, 3), dtype=wp.float32, device=device)
+    wp.launch(kernel=_fill_position, dim=2, inputs=[new_pos, 10.0, 20.0, 30.0], device=device)
+    view.set_world_poses(positions=new_pos)
+
+    ret_pos, _ = view.get_world_poses()
+    pos_torch = wp.to_torch(ret_pos)
+    expected = torch.tensor([[10.0, 20.0, 30.0], [10.0, 20.0, 30.0]], device=device)
+    assert torch.allclose(pos_torch, expected, atol=1e-7), f"Roundtrip failed on {device}: {pos_torch}"
+
+
+@pytest.mark.multi_gpu
+@pytest.mark.parametrize("device", ["cuda:1"])
+def test_fabric_cuda1_no_usd_writeback(device, view_factory):
+    """set_world_poses on cuda:1 does not write back to USD.
+
+    Mirrors test_fabric_set_world_does_not_write_back_to_usd for the cuda:1
+    device to confirm the no-writeback invariant holds across GPU indices.
+    """
+    bundle = view_factory(1, device)
+    view = bundle.view
+
+    stage = sim_utils.get_current_stage()
+    prim = stage.GetPrimAtPath(view.prim_paths[0])
+    xform_cache = UsdGeom.XformCache()
+    t_before = xform_cache.GetLocalToWorldTransform(prim).ExtractTranslation()
+    orig_usd_pos = torch.tensor([float(t_before[0]), float(t_before[1]), float(t_before[2])])
+
+    new_pos = wp.zeros((1, 3), dtype=wp.float32, device=device)
+    wp.launch(kernel=_fill_position, dim=1, inputs=[new_pos, 99.0, 99.0, 99.0], device=device)
+    view.set_world_poses(positions=new_pos)
+
+    # USD must not have moved at all — equality, not approximate.
+    t_after = UsdGeom.XformCache().GetLocalToWorldTransform(prim).ExtractTranslation()
+    usd_pos_after = torch.tensor([float(t_after[0]), float(t_after[1]), float(t_after[2])])
+    assert torch.allclose(usd_pos_after, orig_usd_pos, atol=0.0), (
+        f"USD wrote back on {device}: expected {orig_usd_pos}, got {usd_pos_after}"
+    )
+
+
+@pytest.mark.multi_gpu
+@pytest.mark.parametrize("device", ["cuda:1"])
+def test_fabric_cuda1_scales_roundtrip(device, view_factory):
+    """set_scales -> get_scales roundtrip works on cuda:1.
+
+    Both write paths (``set_world_poses`` and ``set_scales``) call
+    ``_prepare_for_reuse`` and launch on ``self._device``; this test covers
+    the scales path on the non-primary CUDA device.
+    """
+    bundle = view_factory(2, device)
+    view = bundle.view
+
+    new_scales = wp.zeros((2, 3), dtype=wp.float32, device=device)
+    wp.launch(kernel=_fill_position, dim=2, inputs=[new_scales, 2.0, 3.0, 4.0], device=device)
+    view.set_scales(new_scales)
+
+    ret_scales = view.get_scales()
+    scales_torch = wp.to_torch(ret_scales)
+    expected = torch.tensor([[2.0, 3.0, 4.0], [2.0, 3.0, 4.0]], device=device)
+    assert torch.allclose(scales_torch, expected, atol=1e-7), f"Scales roundtrip failed on {device}: {scales_torch}"
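
These tests reuse a small _fill_position Warp kernel defined earlier in the file (outside this diff). From the call sites above (a (count, 3) float32 array plus three scalars), a kernel of the expected shape would look roughly like this; an assumed sketch, not the actual helper:

import warp as wp

@wp.kernel
def _fill_position(out: wp.array(dtype=wp.float32, ndim=2), x: float, y: float, z: float):
    # Write the same (x, y, z) row into every entry of an (n, 3) array,
    # one thread per row.
    i = wp.tid()
    out[i, 0] = x
    out[i, 1] = y
    out[i, 2] = z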
