|
3 | 3 | #include <unordered_map> |
4 | 4 |
|
5 | 5 | #include "build.hpp" |
| 6 | +#include "graph_transformations/graph_transformations.hpp" |
| 7 | +#include "layers_fused/ConvRelu.hpp" |
6 | 8 |
|
7 | 9 | namespace fs = std::filesystem; |
8 | 10 | using namespace it_lab_ai; |
9 | 11 |
|
| 12 | +namespace { |
| 13 | + |
| 14 | +enum class FusionMode { kOff, kPostops, kConvRelu }; |
| 15 | + |
| 16 | +FusionMode parse_fusion_mode(const std::string& value) { |
| 17 | + if (value == "off") { |
| 18 | + return FusionMode::kOff; |
| 19 | + } |
| 20 | + if (value == "postops") { |
| 21 | + return FusionMode::kPostops; |
| 22 | + } |
| 23 | + if (value == "convrelu") { |
| 24 | + return FusionMode::kConvRelu; |
| 25 | + } |
| 26 | + throw std::invalid_argument("Unknown fusion mode: " + value); |
| 27 | +} |
| 28 | + |
| 29 | +void apply_conv_relu_fusion(Graph& graph, Tensor& output, |
| 30 | + const RuntimeOptions& options) { |
| 31 | + if (options.backend == Backend::kOneDnn) { |
| 32 | + throw std::invalid_argument( |
| 33 | + "convrelu fusion is not supported with oneDNN backend"); |
| 34 | + } |
| 35 | + |
| 36 | + Graph subgraph; |
| 37 | + Tensor dummy_input = make_tensor(std::vector<int>({0})); |
| 38 | + auto conv = std::make_shared<ConvolutionalLayer>(); |
| 39 | + auto relu = std::make_shared<EWLayer>("relu"); |
| 40 | + subgraph.setInput(conv, dummy_input); |
| 41 | + subgraph.makeConnection(conv, relu); |
| 42 | + |
| 43 | + Graph fused_graph; |
| 44 | + auto fused_layer = std::make_shared<ConvReluLayer>(); |
| 45 | + changed_subgraphs(graph, subgraph, fused_layer, fused_graph, output, options); |
| 46 | + graph = std::move(fused_graph); |
| 47 | +} |
| 48 | + |
| 49 | +} // namespace |
| 50 | + |
10 | 51 | int main(int argc, char* argv[]) { |
11 | 52 | std::string model_name = "alexnet_mnist"; |
12 | 53 | RuntimeOptions options; |
| 54 | + FusionMode fusion_mode = FusionMode::kPostops; |
13 | 55 |
|
14 | 56 | for (int i = 1; i < argc; ++i) { |
15 | 57 | if (std::string(argv[i]) == "--model" && i + 1 < argc) { |
@@ -47,6 +89,8 @@ int main(int argc, char* argv[]) { |
47 | 89 | } |
48 | 90 | } else if (std::string(argv[i]) == "--threads" && i + 1 < argc) { |
49 | 91 | options.threads = std::stoi(argv[++i]); |
| 92 | + } else if (std::string(argv[i]) == "--fusion" && i + 1 < argc) { |
| 93 | + fusion_mode = parse_fusion_mode(argv[++i]); |
50 | 94 | } |
51 | 95 | } |
52 | 96 |
|
@@ -92,7 +136,11 @@ int main(int argc, char* argv[]) { |
92 | 136 | std::vector<float> vec(75, 3); |
93 | 137 | it_lab_ai::Tensor output = it_lab_ai::make_tensor(vec, sh1); |
94 | 138 | Graph graph; |
95 | | - build_graph_linear(graph, input, output, options, true); |
| 139 | + build_graph_linear(graph, input, output, options, true, |
| 140 | + fusion_mode == FusionMode::kPostops); |
| 141 | + if (fusion_mode == FusionMode::kConvRelu) { |
| 142 | + apply_conv_relu_fusion(graph, output, options); |
| 143 | + } |
96 | 144 |
|
97 | 145 | std::cout << "Starting inference..." << '\n'; |
98 | 146 | try { |
@@ -133,6 +181,9 @@ int main(int argc, char* argv[]) { |
133 | 181 |
|
134 | 182 | Graph graph; |
135 | 183 | build_graph(graph, input, output, json_path, options, false); |
| 184 | + if (fusion_mode == FusionMode::kConvRelu) { |
| 185 | + apply_conv_relu_fusion(graph, output, options); |
| 186 | + } |
136 | 187 |
|
137 | 188 | std::cout << "Starting inference..." << '\n'; |
138 | 189 | try { |
|
0 commit comments