diff --git a/hloc/extract_features.py b/hloc/extract_features.py index ab9456a8..ef587edf 100644 --- a/hloc/extract_features.py +++ b/hloc/extract_features.py @@ -256,6 +256,8 @@ def main( return feature_path device = "cuda" if torch.cuda.is_available() else "cpu" + if device == "cpu" and torch.backends.mps.is_available(): + device = "mps" Model = dynamic_load(extractors, conf["model"]["name"]) model = Model(conf["model"]).eval().to(device) diff --git a/hloc/match_features.py b/hloc/match_features.py index 679e81e9..e60e7175 100644 --- a/hloc/match_features.py +++ b/hloc/match_features.py @@ -234,6 +234,8 @@ def match_from_paths( return device = "cuda" if torch.cuda.is_available() else "cpu" + if device == "cpu" and torch.backends.mps.is_available(): + device = "mps" Model = dynamic_load(matchers, conf["model"]["name"]) model = Model(conf["model"]).eval().to(device)