Skip to content

Commit 69fa6cd

Browse files
committed
fix: update device compatibility for Fabric mode and simplify device handling
1 parent 43de9b3 commit 69fa6cd

1 file changed

Lines changed: 10 additions & 12 deletions

File tree

source/isaaclab_physx/isaaclab_physx/sim/views/fabric_frame_view.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@
2323

2424
logger = logging.getLogger(__name__)
2525

26+
# TODO: Currently we don't support multiple GPUs for Fabric-accelerated views
27+
# because USDRT SelectPrims only supported cuda:0 at the time of writing.
28+
_fabric_supported_devices = ("cpu", "cuda", "cuda:0")
29+
2630

2731
def _to_float32_2d(a: wp.array | torch.Tensor) -> wp.array | torch.Tensor:
2832
"""Ensure array is compatible with Fabric kernels (2-D float32).
@@ -75,10 +79,11 @@ def __init__(
7579
settings = SettingsManager.instance()
7680
self._use_fabric = bool(settings.get("/physics/fabricEnabled", False))
7781

78-
if self._use_fabric and self._device not in ("cuda", "cuda:0"):
82+
if self._use_fabric and self._device not in _fabric_supported_devices:
7983
logger.warning(
8084
f"Fabric mode is not supported on device '{self._device}'. "
81-
"USDRT SelectPrims and Warp fabric arrays only support cuda:0. "
85+
"USDRT SelectPrims and Warp fabric arrays are currently "
86+
f"only supported on {', '.join(_fabric_supported_devices)}. "
8287
"Falling back to standard USD operations. This may impact performance."
8388
)
8489
self._use_fabric = False
@@ -386,17 +391,10 @@ def _initialize_fabric(self) -> None:
386391
)
387392
wp.synchronize()
388393

394+
# The constructor should have taken care of this, but double check here to avoid regressions
395+
assert self._device in _fabric_supported_devices
396+
389397
fabric_device = self._device
390-
if self._device == "cuda":
391-
logger.warning("Fabric device is not specified, defaulting to 'cuda:0'.")
392-
fabric_device = "cuda:0"
393-
elif self._device.startswith("cuda:"):
394-
if self._device != "cuda:0":
395-
logger.debug(
396-
f"SelectPrims only supports cuda:0. Using cuda:0 for SelectPrims "
397-
f"even though simulation device is {self._device}."
398-
)
399-
fabric_device = "cuda:0"
400398

401399
self._fabric_selection = fabric_stage.SelectPrims(
402400
require_attrs=[

0 commit comments

Comments
 (0)