-
Notifications
You must be signed in to change notification settings - Fork 15
Expand file tree
/
Copy pathsettings.py
More file actions
35 lines (26 loc) · 1.06 KB
/
settings.py
File metadata and controls
35 lines (26 loc) · 1.06 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import os
import torch
# Test-suite configuration read from environment variables.
#
# PYTEST_TORCH_DEVICE selects the torch device tests run on (default "cpu").
# PYTEST_TORCH_DTYPE selects the floating-point dtype (default "float32").
# Unsupported values fail fast at import time with a ValueError.

_POSSIBLE_TEST_DEVICES = {"cpu", "cuda:0"}
_POSSIBLE_TEST_DTYPES = {"float32", "float64"}


def _read_env_choice(name: str, default: str, choices: "set[str]") -> str:
    """Return environment variable *name*, validated against *choices*.

    Args:
        name: Name of the environment variable to read.
        default: Value used when the variable is not set.
        choices: The accepted values.

    Returns:
        The variable's value, or *default* when it is unset.

    Raises:
        ValueError: If the resulting value is not one of *choices*.
    """
    value = os.environ.get(name, default)
    if value not in choices:
        # sorted() keeps the message reproducible: a set's repr order depends
        # on string hash randomization and can differ from run to run.
        raise ValueError(
            f"Invalid value of environment variable {name}: {value}.\n"
            f"Possible values: {sorted(choices)}."
        )
    return value


_device_str = _read_env_choice(
    "PYTEST_TORCH_DEVICE", "cpu", _POSSIBLE_TEST_DEVICES
)
# Passing the allowed-set check does not guarantee the device exists here.
if _device_str == "cuda:0" and not torch.cuda.is_available():
    raise ValueError('Requested device "cuda:0" but cuda is not available.')
DEVICE = torch.device(_device_str)

_dtype_str = _read_env_choice(
    "PYTEST_TORCH_DTYPE", "float32", _POSSIBLE_TEST_DTYPES
)
DTYPE = getattr(torch, _dtype_str)  # "float32" => torch.float32