File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change 1515GPU calls that are device-agnostic.
1616"""
1717
18- import torch
18+ try :
19+ import torch
20+ except Exception :
21+ torch = None
1922
2023
2124class AgnosticGPU :
2225 @staticmethod
2326 def configure () -> "AgnosticGPU" :
2427 return (
25- CUDAGPU ()
28+ NoGPU () if torch is None
29+ else CUDAGPU ()
2630 if torch .cuda .is_available ()
2731 else XPUGPU ()
2832 if (hasattr (torch , "xpu" ) and torch .xpu .is_available ())
@@ -45,6 +49,7 @@ def device_count(self) -> int:
4549
4650class CUDAGPU (AgnosticGPU ):
4751 def __init__ (self ):
52+ assert torch is not None
4853 self .name = "cuda"
4954 self .is_accelerator_available = torch .cuda .is_available
5055 self .current_device = torch .cuda .current_device
@@ -53,6 +58,7 @@ def __init__(self):
5358
5459class XPUGPU (AgnosticGPU ):
5560 def __init__ (self ):
61+ assert torch is not None
5662 self .name = "xpu"
5763 self .is_accelerator_available = torch .xpu .is_available
5864 self .current_device = torch .xpu .current_device
@@ -61,6 +67,7 @@ def __init__(self):
6167
6268class MPSGPU (AgnosticGPU ):
6369 def __init__ (self ):
70+ assert torch is not None
6471 self .name = "mps"
6572 self .is_accelerator_available = torch .mps .is_available
6673 # self.current_device = torch.mps.current_device
You can’t perform that action at this time.
0 commit comments