Skip to content

Commit cd69050

Browse files
fix: process_device dropping CUDA device index (#1860)
* Fix process_device dropping CUDA device index * Fix formatting via pre-commit * Fix formatting via ruff
1 parent e7fbec0 commit cd69050

1 file changed

Lines changed: 3 additions & 2 deletions

File tree

sbi/utils/torchutils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,10 @@ def process_device(device: Union[str, torch.device]) -> str:
6262
)
6363
# Else, check whether the custom device is valid.
6464
else:
65-
check_device(device)
6665
if isinstance(device, torch.device):
67-
device = device.type
66+
device = str(device)
67+
68+
check_device(device)
6869

6970
return device
7071

0 commit comments

Comments
 (0)