11#include " graph_manager.hpp"
22
33#include " ../utils.hpp"
4+ #include " infinicore/context/context.hpp"
5+ #include < infinirt.h>
46
57namespace infinicore ::graph {
68
79/* =========================
810 * GraphTensor
911 * ========================= */
1012
11- GraphTensor::GraphTensor (const Tensor &tensor) : Tensor(tensor->to_blob ()) {
12- }
13-
14- void GraphTensor::resume () const {
15- resume_from_blob_ ();
13+ GraphTensor::GraphTensor (const Tensor &tensor) : Tensor(tensor->to_blob_ ()) {
1614}
1715
1816/* =========================
@@ -33,16 +31,91 @@ GraphOperator::~GraphOperator() {
3331 * Graph
3432 * ========================= */
3533
34+ struct Graph ::DeviceGraph {
35+ infinirtGraph_t graph;
36+ infinirtGraphExec_t exec;
37+ infinirtGraphNode_t node;
38+ std::vector<char > log_buffer;
39+
40+ DeviceGraph () {
41+ log_buffer.resize (4 * 1024 );
42+ }
43+
44+ ~DeviceGraph () {
45+ if (exec) {
46+ infinirtGraphExecDestroy (exec);
47+ }
48+ if (graph) {
49+ infinirtGraphDestroy (graph);
50+ }
51+ }
52+
53+ void launch () {
54+ INFINICORE_CHECK_ERROR (infinirtGraphLuanch (exec, context::getStream ()));
55+ }
56+ };
57+
58+ Graph::Graph () {
59+ }
60+
3661void Graph::run () const {
37- for (auto &op : op_list_) {
38- op->run ();
62+ if (device_graph_ != nullptr && device_graph_.get ()->exec != nullptr ) {
63+ device_graph_.get ()->launch ();
64+ } else {
65+ for (auto &op : op_list_) {
66+ op->run ();
67+ }
3968 }
4069}
4170
4271void Graph::add_operator (std::shared_ptr<GraphOperator> op) {
4372 op_list_.push_back (op);
4473}
4574
75+ void Graph::instantiate () {
76+ // Reset device graph
77+ device_graph_ = std::make_unique<DeviceGraph>();
78+
79+ // warmup
80+ for (size_t iter = 0 ; iter < 5 ; ++iter) {
81+ this ->run ();
82+ }
83+ infinicore::context::syncStream ();
84+
85+ if (infinirtStreamBeginCapture (
86+ context::getStream (),
87+ INFINIRT_STREAM_CAPTURE_MODE_GLOBAL)
88+ != INFINI_STATUS_SUCCESS) {
89+ return ;
90+ }
91+
92+ // Run and record
93+ this ->run ();
94+
95+ if (infinirtStreamEndCapture (
96+ context::getStream (),
97+ &device_graph_.get ()->graph )
98+ != INFINI_STATUS_SUCCESS) {
99+ return ;
100+ }
101+
102+ if (infinirtGraphInstantiate (
103+ &device_graph_.get ()->exec ,
104+ device_graph_.get ()->graph ,
105+ &device_graph_.get ()->node ,
106+ device_graph_.get ()->log_buffer .data (),
107+ device_graph_.get ()->log_buffer .size ())
108+ != INFINI_STATUS_SUCCESS) {
109+ static bool warned_once = false ;
110+ if (!warned_once) {
111+ warned_once = true ;
112+ spdlog::warn (" Fail to instantiate device graph: {}" , std::string (device_graph_.get ()->log_buffer .data ()));
113+ }
114+ }
115+ }
116+
117+ Graph::~Graph () = default ;
118+
46119/* =========================
47120 * GraphManager
48121 * ========================= */
@@ -52,19 +125,26 @@ bool GraphManager::is_recording() const {
52125}
53126
54127void GraphManager::start_recording () {
128+ if (is_recording ()) {
129+ spdlog::warn (" Graph is already recording. Previous recording will be dropped." );
130+ }
55131 recording_ = true ;
56132 graph_ = std::make_shared<Graph>();
57133}
58134
59135void GraphManager::add_operator (std::shared_ptr<GraphOperator> op) {
60- INFINICORE_ASSERT (recording_ );
136+ INFINICORE_ASSERT (is_recording () );
61137
62138 graph_->add_operator (op);
63139}
64140
65141std::shared_ptr<Graph> GraphManager::stop_recording () {
66-
142+ if (!is_recording ()) {
143+ spdlog::warn (" Graph is not recording. Please start recording first." );
144+ return nullptr ;
145+ }
67146 recording_ = false ;
147+ graph_->instantiate ();
68148 return std::exchange (graph_, nullptr );
69149}
70150
0 commit comments