diff --git a/cellpose/core.py b/cellpose/core.py index ecfde40c..b3bcb092 100644 --- a/cellpose/core.py +++ b/cellpose/core.py @@ -52,7 +52,14 @@ def _use_gpu_torch(gpu_number=0): except: pass try: - device = torch.device('mps:' + str(gpu_number)) + # ``gpu_number`` may be the string ``"mps"`` (passed straight through + # from ``--gpu_device mps``) rather than a device index, in which case + # ``"mps:mps"`` is not a valid device string. Fall back to the plain + # ``"mps"`` device for any non-integer gpu_number. + if str(gpu_number).isdigit(): + device = torch.device('mps:' + str(gpu_number)) + else: + device = torch.device('mps') _ = torch.zeros((1,1)).to(device) core_logger.info('** TORCH MPS version installed and working. **') return True diff --git a/tests/test_import.py b/tests/test_import.py index 2f4c709f..0e250cd6 100644 --- a/tests/test_import.py +++ b/tests/test_import.py @@ -21,6 +21,23 @@ def test_gpu_check(): core.use_gpu() +def test_assign_device_mps_string(): + # Regression test for #1455: passing ``--gpu_device mps`` (i.e. + # ``device="mps"``) must select the MPS backend, not silently fall back to + # CPU. Previously the literal string "mps" was used as a device index, + # building the invalid ``torch.device("mps:mps")`` which raised and forced + # CPU. Only meaningful where MPS is actually available. + import torch + from cellpose import core + + if not torch.backends.mps.is_available(): + pytest.skip("MPS backend not available") + + device, gpu = core.assign_device(use_torch=True, gpu=True, device="mps") + assert gpu + assert device.type == "mps" + + def itest_model_dir(): import os, pathlib import numpy as np