11#include " graph_manager.hpp"
22
33#include " ../utils.hpp"
4+ #include " infinicore/context/context.hpp"
5+ #include < infinirt.h>
46
57namespace infinicore ::graph {
68
@@ -33,16 +35,91 @@ GraphOperator::~GraphOperator() {
3335 * Graph
3436 * ========================= */
3537
38+ struct Graph ::DeviceGraph {
39+ infinirtGraph_t graph;
40+ infinirtGraphExec_t exec;
41+ infinirtGraphNode_t node;
42+ std::vector<char > log_buffer;
43+
44+ DeviceGraph () {
45+ log_buffer.resize (4 * 1024 );
46+ }
47+
48+ ~DeviceGraph () {
49+ if (exec) {
50+ infinirtGraphExecDestroy (exec);
51+ }
52+ if (graph) {
53+ infinirtGraphDestroy (graph);
54+ }
55+ }
56+
57+ void launch () {
58+ INFINICORE_CHECK_ERROR (infinirtGraphLuanch (exec, context::getStream ()));
59+ }
60+ };
61+
62+ Graph::Graph () {
63+ }
64+
3665void Graph::run () const {
37- for (auto &op : op_list_) {
38- op->run ();
66+ if (device_graph_ != nullptr && device_graph_.get ()->exec != nullptr ) {
67+ device_graph_.get ()->launch ();
68+ } else {
69+ for (auto &op : op_list_) {
70+ op->run ();
71+ }
3972 }
4073}
4174
4275void Graph::add_operator (std::shared_ptr<GraphOperator> op) {
4376 op_list_.push_back (op);
4477}
4578
79+ void Graph::instantiate () {
80+ // Reset device graph
81+ device_graph_ = std::make_unique<DeviceGraph>();
82+
83+ // warmup
84+ for (size_t iter = 0 ; iter < 5 ; ++iter) {
85+ this ->run ();
86+ }
87+ infinicore::context::syncStream ();
88+
89+ if (infinirtStreamBeginCapture (
90+ context::getStream (),
91+ INFINIRT_STREAM_CAPTURE_MODE_GLOBAL)
92+ != INFINI_STATUS_SUCCESS) {
93+ return ;
94+ }
95+
96+ // Run and record
97+ this ->run ();
98+
99+ if (infinirtStreamEndCapture (
100+ context::getStream (),
101+ &device_graph_.get ()->graph )
102+ != INFINI_STATUS_SUCCESS) {
103+ return ;
104+ }
105+
106+ if (infinirtGraphInstantiate (
107+ &device_graph_.get ()->exec ,
108+ device_graph_.get ()->graph ,
109+ &device_graph_.get ()->node ,
110+ device_graph_.get ()->log_buffer .data (),
111+ device_graph_.get ()->log_buffer .size ())
112+ != INFINI_STATUS_SUCCESS) {
113+ static bool warned_once = false ;
114+ if (!warned_once) {
115+ warned_once = true ;
116+ spdlog::warn (" Fail to instantiate device graph: {}" , std::string (device_graph_.get ()->log_buffer .data ()));
117+ }
118+ }
119+ }
120+
121+ Graph::~Graph () = default ;
122+
46123/* =========================
47124 * GraphManager
48125 * ========================= */
@@ -52,19 +129,26 @@ bool GraphManager::is_recording() const {
52129}
53130
54131void GraphManager::start_recording () {
132+ if (is_recording ()) {
133+ spdlog::warn (" Graph is already recording. Previous recording will be dropped." );
134+ }
55135 recording_ = true ;
56136 graph_ = std::make_shared<Graph>();
57137}
58138
59139void GraphManager::add_operator (std::shared_ptr<GraphOperator> op) {
60- INFINICORE_ASSERT (recording_ );
140+ INFINICORE_ASSERT (is_recording () );
61141
62142 graph_->add_operator (op);
63143}
64144
65145std::shared_ptr<Graph> GraphManager::stop_recording () {
66-
146+ if (!is_recording ()) {
147+ spdlog::warn (" Graph is not recording. Please start recording first." );
148+ return nullptr ;
149+ }
67150 recording_ = false ;
151+ graph_->instantiate ();
68152 return std::exchange (graph_, nullptr );
69153}
70154
0 commit comments