From c1a72b633a80ded87e149761d15d6bad1d6bd8e9 Mon Sep 17 00:00:00 2001 From: francescomalandrino Date: Wed, 31 Jul 2024 14:08:45 +0200 Subject: [PATCH] Update base_trainer.py if mps is available and cuda is not, try to use "mps" as the device --- src/pythae/trainers/base_trainer/base_trainer.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/src/pythae/trainers/base_trainer/base_trainer.py b/src/pythae/trainers/base_trainer/base_trainer.py index 073deede..606d5019 100644 --- a/src/pythae/trainers/base_trainer/base_trainer.py +++ b/src/pythae/trainers/base_trainer/base_trainer.py @@ -87,11 +87,12 @@ def __init__( device = self._setup_devices() else: - device = ( - "cuda" - if torch.cuda.is_available() and not self.training_config.no_cuda - else "cpu" - ) + if torch.cuda.is_available() and not self.training_config.no_cuda: + device = "cuda" + elif torch.backends.mps.is_available() and not self.training_config.no_cuda: + device = "mps" + else: + device = "cpu" self.amp_context = ( torch.autocast("cuda") @@ -174,8 +175,9 @@ def _setup_devices(self): device = "cpu" else: - torch.cuda.set_device(self.local_rank) - device = torch.device("cuda", self.local_rank) + if not device == "mps": + torch.cuda.set_device(self.local_rank) + device = torch.device("cuda", self.local_rank) if not dist.is_initialized(): dist.init_process_group(