11#include " graph_manager.hpp"
22
33#include " ../utils.hpp"
4+ #include " infinicore/context/context.hpp"
5+ #include < infinirt.h>
46
57namespace infinicore ::graph {
68
@@ -33,16 +35,85 @@ GraphOperator::~GraphOperator() {
3335 * Graph
3436 * ========================= */
3537
38+ struct Graph ::DeviceGraph {
39+ infinirtGraph_t graph;
40+ infinirtGraphExec_t exec;
41+ infinirtGraphNode_t node;
42+ std::vector<char > log_buffer;
43+
44+ DeviceGraph () {
45+ log_buffer.resize (4 * 1024 );
46+ }
47+
48+ ~DeviceGraph () {
49+ if (exec) {
50+ infinirtGraphExecDestroy (exec);
51+ }
52+ if (graph) {
53+ infinirtGraphDestroy (graph);
54+ }
55+ }
56+
57+ void launch () {
58+ INFINICORE_CHECK_ERROR (infinirtGraphLuanch (exec, context::getStream ()));
59+ }
60+ };
61+
62+ Graph::Graph () {
63+ }
64+
3665void Graph::run () const {
37- for (auto &op : op_list_) {
38- op->run ();
66+ if (device_graph_ != nullptr && device_graph_.get ()->exec != nullptr ) {
67+ device_graph_.get ()->launch ();
68+ } else {
69+ for (auto &op : op_list_) {
70+ op->run ();
71+ }
3972 }
4073}
4174
4275void Graph::add_operator (std::shared_ptr<GraphOperator> op) {
4376 op_list_.push_back (op);
4477}
4578
79+ void Graph::instantiate () {
80+ // Reset device graph
81+ device_graph_ = std::make_unique<DeviceGraph>();
82+
83+ if (infinirtStreamBeginCapture (
84+ context::getStream (),
85+ INFINIRT_STREAM_CAPTURE_MODE_GLOBAL)
86+ != INFINI_STATUS_SUCCESS) {
87+ return ;
88+ }
89+
90+ // Run and record
91+ this ->run ();
92+
93+ if (infinirtStreamEndCapture (
94+ context::getStream (),
95+ &device_graph_.get ()->graph )
96+ != INFINI_STATUS_SUCCESS) {
97+ return ;
98+ }
99+
100+ if (infinirtGraphInstantiate (
101+ &device_graph_.get ()->exec ,
102+ device_graph_.get ()->graph ,
103+ &device_graph_.get ()->node ,
104+ device_graph_.get ()->log_buffer .data (),
105+ device_graph_.get ()->log_buffer .size ())
106+ != INFINI_STATUS_SUCCESS) {
107+ static bool warned_once = false ;
108+ if (!warned_once) {
109+ warned_once = true ;
110+ spdlog::warn (" Fail to instantiate device graph: {}" , std::string (device_graph_.get ()->log_buffer .data ()));
111+ }
112+ }
113+ }
114+
115+ Graph::~Graph () = default ;
116+
46117/* =========================
47118 * GraphManager
48119 * ========================= */
@@ -52,19 +123,26 @@ bool GraphManager::is_recording() const {
52123}
53124
54125void GraphManager::start_recording () {
126+ if (is_recording ()) {
127+ spdlog::warn (" Graph is already recording. Previous recording will be dropped." );
128+ }
55129 recording_ = true ;
56130 graph_ = std::make_shared<Graph>();
57131}
58132
59133void GraphManager::add_operator (std::shared_ptr<GraphOperator> op) {
60- INFINICORE_ASSERT (recording_ );
134+ INFINICORE_ASSERT (is_recording () );
61135
62136 graph_->add_operator (op);
63137}
64138
65139std::shared_ptr<Graph> GraphManager::stop_recording () {
66-
140+ if (!is_recording ()) {
141+ spdlog::warn (" Graph is not recording. Please start recording first." );
142+ return nullptr ;
143+ }
67144 recording_ = false ;
145+ graph_->instantiate ();
68146 return std::exchange (graph_, nullptr );
69147}
70148
0 commit comments