Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion cellpose/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,14 @@ def _use_gpu_torch(gpu_number=0):
except:
pass
try:
device = torch.device('mps:' + str(gpu_number))
# ``gpu_number`` may be the string ``"mps"`` (passed straight through
# from ``--gpu_device mps``) rather than a device index, in which case
# ``"mps:mps"`` is not a valid device string. Fall back to the plain
# ``"mps"`` device for any non-integer gpu_number.
if str(gpu_number).isdigit():
device = torch.device('mps:' + str(gpu_number))
else:
device = torch.device('mps')
_ = torch.zeros((1,1)).to(device)
core_logger.info('** TORCH MPS version installed and working. **')
return True
Expand Down
17 changes: 17 additions & 0 deletions tests/test_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,23 @@ def test_gpu_check():
core.use_gpu()


def test_assign_device_mps_string():
# Regression test for #1455: passing ``--gpu_device mps`` (i.e.
# ``device="mps"``) must select the MPS backend, not silently fall back to
# CPU. Previously the literal string "mps" was used as a device index,
# building the invalid ``torch.device("mps:mps")`` which raised and forced
# CPU. Only meaningful where MPS is actually available.
import torch
from cellpose import core

if not torch.backends.mps.is_available():
pytest.skip("MPS backend not available")

device, gpu = core.assign_device(use_torch=True, gpu=True, device="mps")
assert gpu
assert device.type == "mps"


def itest_model_dir():
import os, pathlib
import numpy as np
Expand Down