Skip to content

Commit 93e7d88

Browse files
authored
issue/492: 修复 infinicore.Tensor.dtype 和 infinicore.Tensor.device 返回类型的问题
1 parent 9a05446 commit 93e7d88

8 files changed

Lines changed: 41 additions & 12 deletions

File tree

include/infinicore/device.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ class Device {
2020
MOORE = INFINI_DEVICE_MOORE,
2121
ILUVATAR = INFINI_DEVICE_ILUVATAR,
2222
KUNLUN = INFINI_DEVICE_KUNLUN,
23-
SUGON = INFINI_DEVICE_SUGON,
23+
HYGON = INFINI_DEVICE_HYGON,
2424
COUNT = INFINI_DEVICE_TYPE_COUNT,
2525
};
2626

python/infinicore/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
complex64,
1010
complex128,
1111
double,
12+
dtype,
1213
float,
1314
float16,
1415
float32,
@@ -37,6 +38,7 @@
3738
__all__ = [
3839
# Classes.
3940
"device",
41+
"dtype",
4042
# Data Types.
4143
"bfloat16",
4244
"bool",

python/infinicore/device.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,23 @@ def _to_infinicore_device(type, index):
6666

6767
index -= 1
6868

69+
@staticmethod
70+
def _from_infinicore_device(infinicore_device):
71+
type = _TORCH_DEVICE_MAP[infinicore_device.type]
72+
73+
base_index = 0
74+
75+
for infinicore_type, torch_type in _TORCH_DEVICE_MAP.items():
76+
if torch_type != type:
77+
continue
78+
79+
if infinicore_type == infinicore_device.type:
80+
break
81+
82+
base_index += _infinicore.get_device_count(infinicore_device)
83+
84+
return device(type, base_index + infinicore_device.index)
85+
6986

7087
_TORCH_DEVICE_MAP = {
7188
_infinicore.Device.Type.CPU: "cpu",
@@ -76,5 +93,5 @@ def _to_infinicore_device(type, index):
7693
_infinicore.Device.Type.MOORE: "musa",
7794
_infinicore.Device.Type.ILUVATAR: "cuda",
7895
_infinicore.Device.Type.KUNLUN: "cuda",
79-
_infinicore.Device.Type.SUGON: "cuda",
96+
_infinicore.Device.Type.HYGON: "cuda",
8097
}

python/infinicore/tensor.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,32 @@
1+
import infinicore.device
2+
import infinicore.dtype
3+
14
from . import _infinicore
25

36

47
class Tensor:
5-
def __init__(self, tensor):
8+
def __init__(self, underlying):
69
"""An internal method. Please do not use this directly."""
710

8-
self._underlying = tensor
11+
self._underlying = underlying
12+
13+
self._dtype = infinicore.dtype(self._underlying.dtype)
14+
15+
self._device = infinicore.device._from_infinicore_device(
16+
self._underlying.device
17+
)
918

1019
@property
1120
def shape(self):
1221
return self._underlying.shape
1322

1423
@property
1524
def dtype(self):
16-
return self._underlying.dtype
25+
return self._dtype
1726

1827
@property
1928
def device(self):
20-
return self._underlying.device
29+
return self._device
2130

2231
@property
2332
def ndim(self):

src/infinicore/device.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,8 @@ std::string Device::toString(const Type &type) {
3737
return "ILUVATAR";
3838
case Type::KUNLUN:
3939
return "KUNLUN";
40-
case Type::SUGON:
41-
return "SUGON";
40+
case Type::HYGON:
41+
return "HYGON";
4242
}
4343

4444
// TODO: Add error handling.

src/infinicore/pybind11/device.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ inline void bind(py::module &m) {
2020
.value("MOORE", Device::Type::MOORE)
2121
.value("ILUVATAR", Device::Type::ILUVATAR)
2222
.value("KUNLUN", Device::Type::KUNLUN)
23-
.value("SUGON", Device::Type::SUGON)
23+
.value("HYGON", Device::Type::HYGON)
2424
.value("COUNT", Device::Type::COUNT);
2525

2626
device

src/infinicore/pybind11/tensor.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ inline void bind(py::module &m) {
1515
.def_property_readonly("strides", [](const Tensor &tensor) { return tensor->strides(); })
1616
.def_property_readonly("ndim", [](const Tensor &tensor) { return tensor->ndim(); })
1717
.def_property_readonly("dtype", [](const Tensor &tensor) { return tensor->dtype(); })
18+
.def_property_readonly("device", [](const Tensor &tensor) { return tensor->device(); })
1819

1920
.def("data_ptr", [](const Tensor &tensor) { return tensor->data(); })
2021
.def("size", [](const Tensor &tensor, std::size_t dim) { return tensor->size(dim); })

test/infinicore/framework/devices.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ class InfiniDeviceEnum:
77
MOORE = 5
88
ILUVATAR = 6
99
KUNLUN = 7
10-
SUGON = 8
10+
HYGON = 8
1111

1212

1313
InfiniDeviceNames = {
@@ -19,7 +19,7 @@ class InfiniDeviceEnum:
1919
InfiniDeviceEnum.MOORE: "Moore",
2020
InfiniDeviceEnum.ILUVATAR: "Iluvatar",
2121
InfiniDeviceEnum.KUNLUN: "Kunlun",
22-
InfiniDeviceEnum.SUGON: "Sugon",
22+
InfiniDeviceEnum.HYGON: "Hygon",
2323
}
2424

2525
# Mapping that maps InfiniDeviceEnum to torch device string
@@ -32,5 +32,5 @@ class InfiniDeviceEnum:
3232
InfiniDeviceEnum.MOORE: "musa",
3333
InfiniDeviceEnum.ILUVATAR: "cuda",
3434
InfiniDeviceEnum.KUNLUN: "cuda",
35-
InfiniDeviceEnum.SUGON: "cuda",
35+
InfiniDeviceEnum.HYGON: "cuda",
3636
}

0 commit comments

Comments
 (0)