-
Notifications
You must be signed in to change notification settings - Fork 115
Expand file tree
/
Copy pathcontext_impl.cc
More file actions
125 lines (95 loc) · 3.79 KB
/
context_impl.cc
File metadata and controls
125 lines (95 loc) · 3.79 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
#include "context_impl.hpp"
#include "../utils.hpp"
namespace infinicore {
thread_local Runtime *ContextImpl::current_runtime_ = nullptr;
Runtime *ContextImpl::getCurrentRuntime() {
return current_runtime_;
}
Runtime *ContextImpl::getCpuRuntime() {
return runtime_table_[int(Device::Type::CPU)][0].get();
}
void ContextImpl::setDevice(Device device) {
if (device == getCurrentRuntime()->device()) {
// Do nothing if the device is already set.
return;
}
if (runtime_table_[int(device.getType())][device.getIndex()] == nullptr) {
// Lazy initialization of runtime if never set before.
runtime_table_[int(device.getType())][device.getIndex()] = std::unique_ptr<Runtime>(new Runtime(device));
current_runtime_ = runtime_table_[int(device.getType())][device.getIndex()].get();
} else {
current_runtime_ = runtime_table_[int(device.getType())][device.getIndex()].get()->activate();
}
}
size_t ContextImpl::getDeviceCount(Device::Type type) {
return runtime_table_[int(type)].size();
}
ContextImpl &ContextImpl::singleton() {
static ContextImpl instance;
return instance;
}
ContextImpl::ContextImpl() {
std::vector<int> device_counter(size_t(Device::Type::COUNT));
INFINICORE_CHECK_ERROR(infinirtGetAllDeviceCount(device_counter.data()));
// Reserve runtime slot for all devices.
runtime_table_[0].resize(device_counter[0]);
runtime_table_[0][0] = std::unique_ptr<Runtime>(new Runtime(Device(Device::Type::CPU, 0)));
// Context will try to use the first non-cpu available device as the default runtime.
for (int i = int(Device::Type::COUNT) - 1; i > 0; i--) {
if (device_counter[i] > 0) {
runtime_table_[i].resize(device_counter[i]);
if (current_runtime_ == nullptr) {
runtime_table_[i][0] = std::unique_ptr<Runtime>(new Runtime(Device(Device::Type(i), 0)));
current_runtime_ = runtime_table_[i][0].get();
}
}
}
if (current_runtime_ == nullptr) {
current_runtime_ = runtime_table_[0][0].get();
}
}
namespace context {
void setDevice(Device device) {
ContextImpl::singleton().setDevice(device);
}
Device getDevice() {
return ContextImpl::singleton().getCurrentRuntime()->device();
}
size_t getDeviceCount(Device::Type type) {
return ContextImpl::singleton().getDeviceCount(type);
}
infinirtStream_t getStream() {
return ContextImpl::singleton().getCurrentRuntime()->stream();
}
infiniopHandle_t getInfiniopHandle() {
return ContextImpl::singleton().getCurrentRuntime()->infiniopHandle();
}
void syncStream() {
return ContextImpl::singleton().getCurrentRuntime()->syncStream();
}
void syncDevice() {
return ContextImpl::singleton().getCurrentRuntime()->syncDevice();
}
std::shared_ptr<Memory> allocateMemory(size_t size) {
return ContextImpl::singleton().getCurrentRuntime()->allocateMemory(size);
}
std::shared_ptr<Memory> allocateHostMemory(size_t size) {
return ContextImpl::singleton().getCpuRuntime()->allocateMemory(size);
}
std::shared_ptr<Memory> allocatePinnedHostMemory(size_t size) {
return ContextImpl::singleton().getCurrentRuntime()->allocatePinnedHostMemory(size);
}
void memcpyH2D(void *dst, const void *src, size_t size) {
return ContextImpl::singleton().getCurrentRuntime()->memcpyH2D(dst, src, size);
}
void memcpyD2H(void *dst, const void *src, size_t size) {
return ContextImpl::singleton().getCurrentRuntime()->memcpyD2H(dst, src, size);
}
void memcpyD2D(void *dst, const void *src, size_t size) {
return ContextImpl::singleton().getCurrentRuntime()->memcpyD2D(dst, src, size);
}
void memcpyH2H(void *dst, const void *src, size_t size) {
return ContextImpl::singleton().getCpuRuntime()->memcpyD2D(dst, src, size);
}
} // namespace context
} // namespace infinicore