-
Notifications
You must be signed in to change notification settings - Fork 121
Expand file tree
/
Copy pathdevice.py
More file actions
97 lines (72 loc) · 2.93 KB
/
Copy pathdevice.py
File metadata and controls
97 lines (72 loc) · 2.93 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
from infinicore.lib import _infinicore
class device:
    """Device descriptor made of a type name and an optional index.

    The type name is one of the strings listed as values of
    ``_TORCH_DEVICE_MAP`` (for example ``"cpu"`` or ``"cuda"``).  Several
    backend device types may share one type name; in that case ``index``
    counts across all of them, one backend after another.

    Attributes:
        type: The device type name.
        index: The device index, or ``None`` when no index was given.
        _underlying: The ``_infinicore.Device`` this descriptor resolves to.
    """

    def __init__(self, type=None, index=None):
        """Build a device from a type name, a ``"type:index"`` string, or a device.

        Args:
            type: A type name, a ``"type:index"`` string, or another
                ``device`` to copy.  Defaults to ``"cpu"`` when ``None``.
            index: Optional device index.  Must be omitted when ``type``
                already carries an index after a colon.

        Raises:
            ValueError: If ``index`` is given together with a ``type`` that
                contains ``":"``, if the part after the colon is not an
                integer, or if no device with that index is available.
            KeyError: If the type name is not a value of ``_TORCH_DEVICE_MAP``.
        """
        if type is None:
            type = "cpu"

        if isinstance(type, device):
            # Copy construction.  The underlying handle is carried over as
            # well, so a copy is as usable as the device it was made from.
            self.type = type.type
            self.index = type.index
            self._underlying = type._underlying
            return

        if ":" in type:
            if index is not None:
                raise ValueError(
                    '`index` should not be provided when `type` contains `":"`.'
                )

            type, index = type.split(":")
            index = int(index)

        self.type = type
        self.index = index

        # A missing index resolves to the first device of that type name,
        # while ``self.index`` itself stays ``None`` for repr/str.
        _type, _index = device._to_infinicore_device(type, index if index else 0)
        self._underlying = _infinicore.Device(_type, _index)

    def __repr__(self):
        """Return ``device(type='...')``, with ``index=`` appended when set."""
        index_part = f", index={self.index}" if self.index is not None else ""
        return f"device(type='{self.type}'{index_part})"

    def __str__(self):
        """Return ``"type"`` or ``"type:index"`` when an index is set."""
        if self.index is None:
            return f"{self.type}"
        return f"{self.type}:{self.index}"

    @staticmethod
    def _to_infinicore_device(type, index):
        """Resolve a type name and index to a backend type and local index.

        Backend types sharing the same type name are walked in enum order;
        ``index`` is consumed by each backend's device count until it falls
        inside one of them.

        Args:
            type: A type name (a value of ``_TORCH_DEVICE_MAP``).
            index: Index counted across every backend sharing that name.

        Returns:
            A ``(backend_type, backend_index)`` tuple.

        Raises:
            KeyError: If ``type`` is not a value of ``_TORCH_DEVICE_MAP``.
            ValueError: If ``index`` is negative or beyond the available
                devices of that type name.
        """
        if type not in _TORCH_DEVICE_MAP.values():
            raise KeyError(type)

        # The last enum member is skipped -- presumably a sentinel such as a
        # member count rather than a real backend; confirm against the enum.
        backend_types = tuple(_infinicore.Device.Type.__members__.values())[:-1]

        remaining = index
        for backend_type in backend_types:
            # Members without an entry in the map are skipped, not an error.
            if _TORCH_DEVICE_MAP.get(backend_type) != type:
                continue

            # Query the count with the enum member itself instead of
            # rebuilding the member from its position in the tuple.
            backend_count = _infinicore.get_device_count(backend_type)
            if 0 <= remaining < backend_count:
                return backend_type, remaining
            remaining -= backend_count

        raise ValueError(f"no `{type}` device with index {index} is available")

    @staticmethod
    def _from_infinicore_device(infinicore_device):
        """Build a ``device`` from a backend device (reads ``.type``/``.index``).

        The resulting index is the backend-local index plus the device counts
        of every backend that shares the same type name and is listed before
        it in ``_TORCH_DEVICE_MAP``.
        NOTE(review): this agrees with ``_to_infinicore_device`` only if the
        map is declared in the same order as the enum -- confirm.

        Args:
            infinicore_device: The backend device to convert.

        Returns:
            The matching ``device``.
        """
        type_name = _TORCH_DEVICE_MAP[infinicore_device.type]

        base_index = 0
        for infinicore_type, torch_type in _TORCH_DEVICE_MAP.items():
            if torch_type != type_name:
                continue

            if infinicore_type == infinicore_device.type:
                break

            # Count the devices of the *preceding backend type*, not of the
            # device being converted.
            base_index += _infinicore.get_device_count(infinicore_type)

        return device(type_name, base_index + infinicore_device.index)
# Maps `_infinicore.Device.Type` members to the device type name exposed by
# `device`. Several backend types share the "cuda" name, which is why `device`
# indices are counted across all backends mapped to the same name.
# NOTE: insertion order matters -- `device._from_infinicore_device` walks this
# dict in order to sum the device counts of the backends listed before the
# one being converted.
_TORCH_DEVICE_MAP = {
_infinicore.Device.Type.CPU: "cpu",
_infinicore.Device.Type.NVIDIA: "cuda",
_infinicore.Device.Type.CAMBRICON: "mlu",
_infinicore.Device.Type.ASCEND: "npu",
_infinicore.Device.Type.METAX: "cuda",
_infinicore.Device.Type.MOORE: "musa",
_infinicore.Device.Type.ILUVATAR: "cuda",
_infinicore.Device.Type.KUNLUN: "cuda",
_infinicore.Device.Type.HYGON: "cuda",
}