Skip to content

Commit f00c06d

Browse files
Merge pull request #955 from InfiniTensor/issue/811
issue/811 support cuda graph capture
2 parents 148b475 + 3a8c686 commit f00c06d

File tree

14 files changed

+392
-8
lines changed

14 files changed

+392
-8
lines changed

include/infinicore/graph/graph.hpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,17 +31,21 @@ class GraphOperator {
3131

3232
class Graph { // Recorded list of GraphOperators, optionally backed by a captured device graph.
3333
public:
34-
Graph() = default;
35-
~Graph() = default;
34+
Graph(); // Defined out of line in graph.cc.
35+
~Graph(); // Defined out of line in graph.cc, where DeviceGraph is a complete type for the unique_ptr member.
3636

3737
void run() const; // Launches the instantiated device graph when one exists; otherwise runs every operator in op_list_ in order.
3838

3939
protected:
4040
void add_operator(std::shared_ptr<GraphOperator> op); // Appends op to op_list_.
41-
41+
void instantiate(); // Warms up, then attempts to capture run() into a device graph; called from GraphManager::stop_recording().
4242
std::vector<std::shared_ptr<GraphOperator>> op_list_; // Operators in the order they were added.
4343

4444
friend class GraphManager; // GraphManager calls the protected add_operator() and instantiate().
45+
46+
private:
47+
struct DeviceGraph; // Backend graph handles; defined in graph.cc.
48+
std::unique_ptr<DeviceGraph> device_graph_; // Null until instantiate() runs.
4549
};
4650
} // namespace infinicore::graph
4751

include/infinirt.h

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66

77
typedef void *infinirtStream_t;
88
typedef void *infinirtEvent_t;
9+
typedef void *infinirtGraph_t; // Opaque captured-graph handle (the CUDA backend casts it to cudaGraph_t).
10+
typedef void *infinirtGraphNode_t; // Opaque graph-node handle (the CUDA backend casts a pointer to it to cudaGraphNode_t *).
11+
typedef void *infinirtGraphExec_t; // Opaque executable-graph handle (the CUDA backend casts it to cudaGraphExec_t).
912

1013
__C __export infiniStatus_t infinirtInit();
1114

@@ -63,4 +66,24 @@ __C __export infiniStatus_t infinirtMemcpyAsync(void *dst, const void *src, size
6366
__C __export infiniStatus_t infinirtMallocAsync(void **p_ptr, size_t size, infinirtStream_t stream);
6467
__C __export infiniStatus_t infinirtFreeAsync(void *ptr, infinirtStream_t stream);
6568

69+
// Graph
70+
typedef enum { // Stream capture modes accepted by infinirtStreamBeginCapture.
71+
INFINIRT_STREAM_CAPTURE_MODE_GLOBAL = 0, // Mapped to cudaStreamCaptureModeGlobal by the CUDA backend.
72+
INFINIRT_STREAM_CAPTURE_MODE_THREAD_LOCAL = 1, // Mapped to cudaStreamCaptureModeThreadLocal by the CUDA backend.
73+
INFINIRT_STREAM_CAPTURE_MODE_RELAXED = 2, // Mapped to cudaStreamCaptureModeRelaxed by the CUDA backend.
74+
75+
} infinirtStreamCaptureMode_t;
76+
77+
__C __export infiniStatus_t infinirtStreamBeginCapture(infinirtStream_t stream, infinirtStreamCaptureMode_t mode); // Starts capturing work submitted to stream.
78+
__C __export infiniStatus_t infinirtStreamEndCapture(infinirtStream_t stream, infinirtGraph_t *graph_ptr); // Ends capture on stream and writes the captured graph to *graph_ptr.
79+
__C __export infiniStatus_t infinirtGraphDestroy(infinirtGraph_t graph); // Destroys a graph obtained from infinirtStreamEndCapture.
80+
__C __export infiniStatus_t infinirtGraphInstantiate( // Builds an executable graph from a captured graph.
81+
infinirtGraphExec_t *graph_exec_ptr, // out: executable graph handle
82+
infinirtGraph_t graph, // in: captured graph
83+
infinirtGraphNode_t *node_ptr, // out: node handle forwarded to the backend (error node in the CUDA call, presumably — confirm)
84+
char *log_buffer, // out: backend log text
85+
size_t buffer_size); // in: size of log_buffer in bytes
86+
__C __export infiniStatus_t infinirtGraphExecDestroy(infinirtGraphExec_t graph_exec); // Destroys an executable graph.
87+
__C __export infiniStatus_t infinirtGraphLuanch(infinirtGraphExec_t graph_exec, infinirtStream_t stream); // Launches graph_exec on stream. NOTE(review): "Luanch" looks like a misspelling of "Launch"; it is part of the exported name, so fixing it needs a compatibility alias.
88+
6689
#endif // __INFINIRT_API_H__

src/infinicore/graph/graph.cc

Lines changed: 88 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
#include "graph_manager.hpp"
22

33
#include "../utils.hpp"
4+
#include "infinicore/context/context.hpp"
5+
#include <infinirt.h>
46

57
namespace infinicore::graph {
68

@@ -33,16 +35,91 @@ GraphOperator::~GraphOperator() {
3335
* Graph
3436
* ========================= */
3537

38+
struct Graph::DeviceGraph { // Owns the backend graph handles for one Graph and releases them on destruction.
39+
infinirtGraph_t graph = nullptr; // Captured graph written by infinirtStreamEndCapture; null until then.
40+
infinirtGraphExec_t exec = nullptr; // Executable graph written by infinirtGraphInstantiate; null until then, which makes Graph::run() fall back to per-operator execution.
41+
infinirtGraphNode_t node = nullptr; // Out-parameter slot handed to infinirtGraphInstantiate.
42+
std::vector<char> log_buffer; // Receives the instantiation log text.
43+
44+
DeviceGraph() {
45+
log_buffer.resize(4 * 1024); // 4 KiB, zero-filled by resize.
46+
}
47+
48+
~DeviceGraph() {
49+
if (exec) { // Handles start out null, so a skipped or failed capture destroys nothing.
50+
infinirtGraphExecDestroy(exec);
51+
}
52+
if (graph) {
53+
infinirtGraphDestroy(graph);
54+
}
55+
}
56+
57+
void launch() { // Replays the executable graph on the current context stream; callers check exec for null first.
58+
INFINICORE_CHECK_ERROR(infinirtGraphLuanch(exec, context::getStream()));
59+
}
60+
};
61+
62+
Graph::Graph() { // Nothing to set up: op_list_ starts empty and device_graph_ starts null.
63+
}
64+
3665
void Graph::run() const {
37-
for (auto &op : op_list_) {
38-
op->run();
66+
if (device_graph_ != nullptr && device_graph_.get()->exec != nullptr) { // Replay path: an executable device graph is available.
67+
device_graph_.get()->launch();
68+
} else { // Fallback path: execute each recorded operator in insertion order.
69+
for (auto &op : op_list_) {
70+
op->run();
71+
}
3972
}
4073
}
4174

4275
void Graph::add_operator(std::shared_ptr<GraphOperator> op) { // Appends op to the end of the recorded operator list.
4376
op_list_.push_back(op);
4477
}
4578

79+
void Graph::instantiate() { // Tries to turn the recorded operators into a replayable device graph; returns silently when capture is unavailable.
80+
// Reset device graph: any previously held handles are released by ~DeviceGraph.
81+
device_graph_ = std::make_unique<DeviceGraph>();
82+
83+
// Warm up by running the graph five times before capturing.
84+
for (size_t iter = 0; iter < 5; ++iter) {
85+
this->run(); // NOTE(review): run() inspects device_graph_->exec, so a fresh DeviceGraph must hold a null exec for this to take the per-operator path — confirm.
86+
}
87+
infinicore::context::syncStream();
88+
89+
if (infinirtStreamBeginCapture(
90+
context::getStream(),
91+
INFINIRT_STREAM_CAPTURE_MODE_GLOBAL)
92+
!= INFINI_STATUS_SUCCESS) {
93+
return; // Capture could not start (the ascend, bang and cpu backends report NOT_SUPPORTED); run() keeps using the operator list.
94+
}
95+
96+
// Run and record: work submitted by the operators goes into the capture.
97+
this->run();
98+
99+
if (infinirtStreamEndCapture(
100+
context::getStream(),
101+
&device_graph_.get()->graph)
102+
!= INFINI_STATUS_SUCCESS) {
103+
return; // No graph was produced; leave exec unset.
104+
}
105+
106+
if (infinirtGraphInstantiate(
107+
&device_graph_.get()->exec,
108+
device_graph_.get()->graph,
109+
&device_graph_.get()->node,
110+
device_graph_.get()->log_buffer.data(),
111+
device_graph_.get()->log_buffer.size())
112+
!= INFINI_STATUS_SUCCESS) {
113+
static bool warned_once = false; // Process-wide flag: the warning is emitted at most once across all Graph objects.
114+
if (!warned_once) {
115+
warned_once = true;
116+
spdlog::warn("Fail to instantiate device graph: {}", std::string(device_graph_.get()->log_buffer.data())); // assumes the backend leaves log_buffer NUL-terminated — TODO confirm
117+
}
118+
}
119+
}
120+
121+
Graph::~Graph() = default; // Defaulted here, after DeviceGraph is defined, so unique_ptr<DeviceGraph> can destroy a complete type.
122+
46123
/* =========================
47124
* GraphManager
48125
* ========================= */
@@ -52,19 +129,26 @@ bool GraphManager::is_recording() const {
52129
}
53130

54131
void GraphManager::start_recording() { // Begins a new recording with a fresh, empty Graph.
132+
if (is_recording()) { // Warn only; the assignment below replaces the in-progress graph.
133+
spdlog::warn("Graph is already recording. Previous recording will be dropped.");
134+
}
55135
recording_ = true;
56136
graph_ = std::make_shared<Graph>();
57137
}
58138

59139
void GraphManager::add_operator(std::shared_ptr<GraphOperator> op) { // Forwards op to the graph currently being recorded.
60-
INFINICORE_ASSERT(recording_);
140+
INFINICORE_ASSERT(is_recording()); // Adding an operator outside a recording is treated as a programming error.
61141

62142
graph_->add_operator(op);
63143
}
64144

65145
std::shared_ptr<Graph> GraphManager::stop_recording() { // Ends the recording and returns the recorded graph, or nullptr when nothing was being recorded.
66-
146+
if (!is_recording()) {
147+
spdlog::warn("Graph is not recording. Please start recording first.");
148+
return nullptr;
149+
}
67150
recording_ = false;
151+
graph_->instantiate(); // Attempt device-graph capture before handing the graph to the caller.
68152
return std::exchange(graph_, nullptr); // Hand over ownership and clear the manager's pointer.
69153
}
70154

src/infiniop/devices/nvidia/nvidia_common.cu

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@ Handle::Internal::Internal(int device_id) {
2323
_grid_size[0] = prop.maxGridSize[0];
2424
_grid_size[1] = prop.maxGridSize[1];
2525
_grid_size[2] = prop.maxGridSize[2];
26+
this->useCublas(nullptr, [](cublasHandle_t handle) { return INFINI_STATUS_SUCCESS; });
27+
#ifdef ENABLE_CUDNN_API
28+
this->useCudnn(nullptr, [](cudnnHandle_t handle) { return INFINI_STATUS_SUCCESS; });
29+
#endif
2630
}
2731

2832
infiniStatus_t Handle::Internal::useCublas(cudaStream_t stream, const Fn<cublasHandle_t> &f) const {

src/infinirt/ascend/infinirt_ascend.cc

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,5 +150,35 @@ infiniStatus_t mallocAsync(void **p_ptr, size_t size, infinirtStream_t stream) {
150150
infiniStatus_t freeAsync(void *ptr, infinirtStream_t stream) {
151151
return freeDevice(ptr);
152152
}
153+
154+
infiniStatus_t streamBeginCapture(infinirtStream_t stream, infinirtStreamCaptureMode_t mode) { // Stub: this backend does not implement stream capture.
155+
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
156+
}
157+
158+
infiniStatus_t streamEndCapture(infinirtStream_t stream, infinirtGraph_t *graph_ptr) { // Stub: *graph_ptr is left untouched.
159+
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
160+
}
161+
162+
infiniStatus_t graphDestroy(infinirtGraph_t graph) { // Stub: nothing is destroyed.
163+
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
164+
}
165+
166+
infiniStatus_t graphInstantiate( // Stub: no output parameter is written.
167+
infinirtGraphExec_t *graph_exec_ptr,
168+
infinirtGraph_t graph,
169+
infinirtGraphNode_t *node_ptr,
170+
char *log_buffer,
171+
size_t buffer_size) {
172+
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
173+
}
174+
175+
infiniStatus_t graphExecDestroy(infinirtGraphExec_t graph_exec) { // Stub: nothing is destroyed.
176+
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
177+
}
178+
179+
infiniStatus_t graphLuanch(infinirtGraphExec_t graph_exec, infinirtStream_t stream) { // Stub: nothing is launched.
180+
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
181+
}
182+
153183
} // namespace infinirt::ascend
154184
#undef CHECK_ACLRT

src/infinirt/bang/infinirt_bang.cc

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,4 +142,34 @@ infiniStatus_t freeAsync(void *ptr, infinirtStream_t stream) {
142142
CHECK_BANGRT(cnrtFree(ptr));
143143
return INFINI_STATUS_SUCCESS;
144144
}
145+
146+
infiniStatus_t streamBeginCapture(infinirtStream_t stream, infinirtStreamCaptureMode_t mode) { // Stub: this backend does not implement stream capture.
147+
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
148+
}
149+
150+
infiniStatus_t streamEndCapture(infinirtStream_t stream, infinirtGraph_t *graph_ptr) { // Stub: *graph_ptr is left untouched.
151+
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
152+
}
153+
154+
infiniStatus_t graphDestroy(infinirtGraph_t graph) { // Stub: nothing is destroyed.
155+
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
156+
}
157+
158+
infiniStatus_t graphInstantiate( // Stub: no output parameter is written.
159+
infinirtGraphExec_t *graph_exec_ptr,
160+
infinirtGraph_t graph,
161+
infinirtGraphNode_t *node_ptr,
162+
char *log_buffer,
163+
size_t buffer_size) {
164+
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
165+
}
166+
167+
infiniStatus_t graphExecDestroy(infinirtGraphExec_t graph_exec) { // Stub: nothing is destroyed.
168+
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
169+
}
170+
171+
infiniStatus_t graphLuanch(infinirtGraphExec_t graph_exec, infinirtStream_t stream) { // Stub: nothing is launched.
172+
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
173+
}
174+
145175
} // namespace infinirt::bang

src/infinirt/cpu/infinirt_cpu.cc

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,4 +116,33 @@ infiniStatus_t freeAsync(void *ptr, infinirtStream_t stream) {
116116
return freeDevice(ptr);
117117
}
118118

119+
infiniStatus_t streamBeginCapture(infinirtStream_t stream, infinirtStreamCaptureMode_t mode) { // Stub: this backend does not implement stream capture.
120+
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
121+
}
122+
123+
infiniStatus_t streamEndCapture(infinirtStream_t stream, infinirtGraph_t *graph_ptr) { // Stub: *graph_ptr is left untouched.
124+
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
125+
}
126+
127+
infiniStatus_t graphDestroy(infinirtGraph_t graph) { // Stub: nothing is destroyed.
128+
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
129+
}
130+
131+
infiniStatus_t graphInstantiate( // Stub: no output parameter is written.
132+
infinirtGraphExec_t *graph_exec_ptr,
133+
infinirtGraph_t graph,
134+
infinirtGraphNode_t *node_ptr,
135+
char *log_buffer,
136+
size_t buffer_size) {
137+
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
138+
}
139+
140+
infiniStatus_t graphExecDestroy(infinirtGraphExec_t graph_exec) { // Stub: nothing is destroyed.
141+
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
142+
}
143+
144+
infiniStatus_t graphLuanch(infinirtGraphExec_t graph_exec, infinirtStream_t stream) { // Stub: nothing is launched.
145+
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
146+
}
147+
119148
} // namespace infinirt::cpu

src/infinirt/cuda/infinirt_cuda.cu

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,4 +176,53 @@ infiniStatus_t freeAsync(void *ptr, infinirtStream_t stream) {
176176
RUN_CUDART(cudaFreeAsync(ptr, (cudaStream_t)stream));
177177
return INFINI_STATUS_SUCCESS;
178178
}
179+
180+
infiniStatus_t streamBeginCapture(infinirtStream_t stream, infinirtStreamCaptureMode_t mode) { // Translates the infinirt capture mode to its CUDA counterpart and begins capture on stream.
181+
cudaStreamCaptureMode graph_mode;
182+
if (mode == INFINIRT_STREAM_CAPTURE_MODE_GLOBAL) {
183+
graph_mode = cudaStreamCaptureModeGlobal;
184+
} else if (mode == INFINIRT_STREAM_CAPTURE_MODE_THREAD_LOCAL) {
185+
graph_mode = cudaStreamCaptureModeThreadLocal;
186+
} else if (mode == INFINIRT_STREAM_CAPTURE_MODE_RELAXED) {
187+
graph_mode = cudaStreamCaptureModeRelaxed;
188+
} else {
189+
return INFINI_STATUS_BAD_PARAM; // Unknown mode value: reject before touching the stream.
190+
}
191+
192+
CHECK_CUDART(cudaStreamBeginCapture((cudaStream_t)stream, graph_mode)); // CHECK_CUDART is defined outside this view; presumably returns an error status on failure — confirm.
193+
194+
return INFINI_STATUS_SUCCESS;
195+
}
196+
197+
infiniStatus_t streamEndCapture(infinirtStream_t stream, infinirtGraph_t *graph_ptr) { // Ends capture on stream and stores the resulting cudaGraph_t in *graph_ptr.
198+
cudaGraph_t graph;
199+
CHECK_CUDART(cudaStreamEndCapture((cudaStream_t)stream, &graph)); // presumably returns early on failure, leaving *graph_ptr unwritten — macro not visible here
200+
*graph_ptr = graph;
201+
return INFINI_STATUS_SUCCESS;
202+
}
203+
204+
infiniStatus_t graphDestroy(infinirtGraph_t graph) { // Destroys the cudaGraph_t behind the opaque handle.
205+
RUN_CUDART(cudaGraphDestroy((cudaGraph_t)graph)); // RUN_CUDART is defined outside this view; its failure handling is not visible here.
206+
return INFINI_STATUS_SUCCESS;
207+
}
208+
209+
infiniStatus_t graphInstantiate( // Instantiates a captured CUDA graph into an executable graph.
210+
infinirtGraphExec_t *graph_exec_ptr, // out: receives the cudaGraphExec_t
211+
infinirtGraph_t graph, // in: captured cudaGraph_t
212+
infinirtGraphNode_t *node_ptr, // out: forwarded to CUDA as a cudaGraphNode_t *
213+
char *log_buffer, // out: forwarded to CUDA as the log buffer
214+
size_t buffer_size) { // in: size of log_buffer
215+
CHECK_CUDART(cudaGraphInstantiate((cudaGraphExec_t *)graph_exec_ptr, (cudaGraph_t)graph, (cudaGraphNode_t *)node_ptr, log_buffer, buffer_size)); // NOTE(review): five-argument form of cudaGraphInstantiate — confirm it is available on every supported CUDA toolkit version.
216+
return INFINI_STATUS_SUCCESS;
217+
}
218+
219+
infiniStatus_t graphExecDestroy(infinirtGraphExec_t graph_exec) { // Destroys the cudaGraphExec_t behind the opaque handle.
220+
RUN_CUDART(cudaGraphExecDestroy((cudaGraphExec_t)graph_exec)); // RUN_CUDART is defined outside this view; its failure handling is not visible here.
221+
return INFINI_STATUS_SUCCESS;
222+
}
223+
224+
infiniStatus_t graphLuanch(infinirtGraphExec_t graph_exec, infinirtStream_t stream) { // Launches the executable graph on stream via cudaGraphLaunch.
225+
CHECK_CUDART(cudaGraphLaunch((cudaGraphExec_t)graph_exec, (cudaStream_t)stream)); // CHECK_CUDART is defined outside this view; presumably returns an error status on failure — confirm.
226+
return INFINI_STATUS_SUCCESS;
227+
}
179228
}

src/infinirt/infinirt.cc

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,3 +192,32 @@ __C infiniStatus_t infinirtMallocAsync(void **p_ptr, size_t size, infinirtStream
192192
__C infiniStatus_t infinirtFreeAsync(void *ptr, infinirtStream_t stream) {
193193
INFINIRT_CALL_DEVICE_API(freeAsync, (ptr, stream));
194194
}
195+
196+
__C infiniStatus_t infinirtStreamBeginCapture(infinirtStream_t stream, infinirtStreamCaptureMode_t mode) { // Public entry point; hands off to the backend's streamBeginCapture.
197+
INFINIRT_CALL_DEVICE_API(streamBeginCapture, (stream, mode)); // Macro defined outside this view; presumably selects the current device's backend and returns its status — confirm.
198+
}
199+
200+
__C infiniStatus_t infinirtStreamEndCapture(infinirtStream_t stream, infinirtGraph_t *graph_ptr) { // Public entry point; hands off to the backend's streamEndCapture.
201+
INFINIRT_CALL_DEVICE_API(streamEndCapture, (stream, graph_ptr));
202+
}
203+
204+
__C infiniStatus_t infinirtGraphDestroy(infinirtGraph_t graph) { // Public entry point; hands off to the backend's graphDestroy.
205+
INFINIRT_CALL_DEVICE_API(graphDestroy, (graph));
206+
}
207+
208+
__C infiniStatus_t infinirtGraphInstantiate( // Public entry point; hands off to the backend's graphInstantiate with the arguments unchanged.
209+
infinirtGraphExec_t *graph_exec_ptr,
210+
infinirtGraph_t graph,
211+
infinirtGraphNode_t *node_ptr,
212+
char *log_buffer,
213+
size_t buffer_size) {
214+
INFINIRT_CALL_DEVICE_API(graphInstantiate, (graph_exec_ptr, graph, node_ptr, log_buffer, buffer_size));
215+
}
216+
217+
__C infiniStatus_t infinirtGraphExecDestroy(infinirtGraphExec_t graph_exec) { // Public entry point; hands off to the backend's graphExecDestroy.
218+
INFINIRT_CALL_DEVICE_API(graphExecDestroy, (graph_exec));
219+
}
220+
221+
__C infiniStatus_t infinirtGraphLuanch(infinirtGraphExec_t graph_exec, infinirtStream_t stream) { // Public entry point; hands off to the backend's graphLuanch (spelling matches the exported declaration).
222+
INFINIRT_CALL_DEVICE_API(graphLuanch, (graph_exec, stream));
223+
}

0 commit comments

Comments
 (0)