|
23 | 23 |
|
24 | 24 | logger = logging.getLogger(__name__) |
25 | 25 |
|
| 26 | +# TODO: Currently we don't support multiple GPUs for Fabric-accelerated views |
| 27 | +# because USDRT SelectPrims only supported cuda:0 at the time of writing. |
| 28 | +_fabric_supported_devices = ("cpu", "cuda", "cuda:0") |
| 29 | + |
26 | 30 |
|
27 | 31 | def _to_float32_2d(a: wp.array | torch.Tensor) -> wp.array | torch.Tensor: |
28 | 32 | """Ensure array is compatible with Fabric kernels (2-D float32). |
@@ -75,10 +79,11 @@ def __init__( |
75 | 79 | settings = SettingsManager.instance() |
76 | 80 | self._use_fabric = bool(settings.get("/physics/fabricEnabled", False)) |
77 | 81 |
|
78 | | - if self._use_fabric and self._device not in ("cuda", "cuda:0"): |
| 82 | + if self._use_fabric and self._device not in _fabric_supported_devices: |
79 | 83 | logger.warning( |
80 | 84 | f"Fabric mode is not supported on device '{self._device}'. " |
81 | | - "USDRT SelectPrims and Warp fabric arrays only support cuda:0. " |
| 85 | + "USDRT SelectPrims and Warp fabric arrays are currently " |
| 86 | + f"only supported on {', '.join(_fabric_supported_devices)}. " |
82 | 87 | "Falling back to standard USD operations. This may impact performance." |
83 | 88 | ) |
84 | 89 | self._use_fabric = False |
@@ -386,17 +391,10 @@ def _initialize_fabric(self) -> None: |
386 | 391 | ) |
387 | 392 | wp.synchronize() |
388 | 393 |
|
| 394 | + # The constructor should have taken care of this, but double check here to avoid regressions |
| 395 | + assert self._device in _fabric_supported_devices |
| 396 | + |
389 | 397 | fabric_device = self._device |
390 | | - if self._device == "cuda": |
391 | | - logger.warning("Fabric device is not specified, defaulting to 'cuda:0'.") |
392 | | - fabric_device = "cuda:0" |
393 | | - elif self._device.startswith("cuda:"): |
394 | | - if self._device != "cuda:0": |
395 | | - logger.debug( |
396 | | - f"SelectPrims only supports cuda:0. Using cuda:0 for SelectPrims " |
397 | | - f"even though simulation device is {self._device}." |
398 | | - ) |
399 | | - fabric_device = "cuda:0" |
400 | 398 |
|
401 | 399 | self._fabric_selection = fabric_stage.SelectPrims( |
402 | 400 | require_attrs=[ |
|
0 commit comments