Skip to content

Commit 391209f

Browse files
committed
handle the case where torch is not importable
1 parent b056a2a commit 391209f

1 file changed

Lines changed: 9 additions & 2 deletions

File tree

src/transformers/utils/agnostic.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,18 @@
1515
GPU calls that are device-agnostic.
1616
"""
1717

18-
import torch
18+
try:
19+
import torch
20+
except Exception:
21+
torch = None
1922

2023

2124
class AgnosticGPU:
2225
@staticmethod
2326
def configure() -> "AgnosticGPU":
2427
return (
25-
CUDAGPU()
28+
NoGPU() if torch is None
29+
else CUDAGPU()
2630
if torch.cuda.is_available()
2731
else XPUGPU()
2832
if (hasattr(torch, "xpu") and torch.xpu.is_available())
@@ -45,6 +49,7 @@ def device_count(self) -> int:
4549

4650
class CUDAGPU(AgnosticGPU):
4751
def __init__(self):
52+
assert torch is not None
4853
self.name = "cuda"
4954
self.is_accelerator_available = torch.cuda.is_available
5055
self.current_device = torch.cuda.current_device
@@ -53,6 +58,7 @@ def __init__(self):
5358

5459
class XPUGPU(AgnosticGPU):
5560
def __init__(self):
61+
assert torch is not None
5662
self.name = "xpu"
5763
self.is_accelerator_available = torch.xpu.is_available
5864
self.current_device = torch.xpu.current_device
@@ -61,6 +67,7 @@ def __init__(self):
6167

6268
class MPSGPU(AgnosticGPU):
6369
def __init__(self):
70+
assert torch is not None
6471
self.name = "mps"
6572
self.is_accelerator_available = torch.mps.is_available
6673
# self.current_device = torch.mps.current_device

0 commit comments

Comments
 (0)