-
Notifications
You must be signed in to change notification settings - Fork 10
Expand file tree
/
Copy path_utils.py
More file actions
54 lines (39 loc) · 1.35 KB
/
_utils.py
File metadata and controls
54 lines (39 loc) · 1.35 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
"""Utils."""
from __future__ import annotations
import importlib
from typing import TYPE_CHECKING, TypeVar
import torch
if TYPE_CHECKING:
from types import ModuleType
# Generic type variable used by `_funcs_to_dict` so the returned dict's value
# type matches the type of the callables passed in.
T = TypeVar("T")
def _funcs_to_dict(*funcs: T) -> dict[str, T]:
    """Index the given functions by their ``__name__``.

    Functions are inserted in argument order, so when two of them share a
    name, the one passed later replaces the earlier entry.

    Args:
        *funcs: Functions to index.

    Returns:
        Mapping from each function's ``__name__`` to the function itself.
    """
    by_name: dict[str, T] = {}
    for fn in funcs:
        by_name[fn.__name__] = fn  # type: ignore[attr-defined]
    return by_name
def detect_device() -> str:
    """Automatically detect the best available torch device type.

    Preference order is CUDA, then Apple MPS, then CPU.

    Returns:
        One of ``"cuda"``, ``"mps"`` or ``"cpu"``.
    """
    if torch.cuda.is_available():
        return "cuda"
    # ``torch.backends.mps.is_available()`` is the documented MPS availability
    # check and exists in every torch release that supports MPS. The
    # ``torch.mps.is_available`` spelling is missing from older releases and
    # would raise AttributeError on hosts without CUDA.
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"
def require(dependency: str, extra: str | None = None) -> ModuleType:
    """Import ``dependency`` or fail with an actionable ImportError.

    Args:
        dependency: Name of the module to import, as accepted by
            ``importlib.import_module``.
        extra: Optional name of the ``autointent`` extra that provides the
            module; when truthy, the error message includes a ``pip install``
            hint for it.

    Returns:
        The imported module object.

    Raises:
        ImportError: If importing ``dependency`` raises ImportError; the
            original exception is chained as ``__cause__``.
    """
    try:
        module = importlib.import_module(dependency)
    except ImportError as err:
        hint = ""
        if extra:
            hint = f" Install with `pip install autointent[{extra}]`."
        raise ImportError(
            f"Missing dependency '{dependency}' required for this feature.{hint}"
        ) from err
    return module