Skip to content

Commit 5cb9583

Browse files
authored
Concat&&OneDNN (#285)
1 parent dd595d7 commit 5cb9583

5 files changed

Lines changed: 852 additions & 1 deletion

File tree

app/Graph/build.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -519,7 +519,7 @@ ParseResult parse_json_model(RuntimeOptions options,
519519
concat_connections[layer_name].push_back(base_input_name);
520520
}
521521
}
522-
auto concat_layer = std::make_shared<it_lab_ai::ConcatLayer>(axis);
522+
auto concat_layer = LayerFactory::createConcatLayer(axis, options);
523523
layer = concat_layer;
524524
concat_connected_inputs[layer_name] = std::unordered_set<std::string>();
525525
} else if (layer_type == "Split") {

app/Graph/build.hpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
#include "layers/Tensor.hpp"
3434
#include "layers/TransposeLayer.hpp"
3535
#include "layers_oneDNN/BinaryOpLayer.hpp"
36+
#include "layers_oneDNN/ConcatLayer.hpp"
3637
#include "layers_oneDNN/ConvLayer.hpp"
3738
#include "layers_oneDNN/EWLayer.hpp"
3839
#include "layers_oneDNN/PoolingLayer.hpp"
@@ -134,6 +135,14 @@ class LayerFactory {
134135
return std::make_shared<PoolingLayer>(shape, strides, pads, dilations,
135136
ceil_mode, PoolType);
136137
}
138+
139+
static std::shared_ptr<Layer> createConcatLayer(
140+
int64_t axis, const RuntimeOptions& options) {
141+
if (options.backend == Backend::kOneDnn) {
142+
return std::make_shared<ConcatLayerOneDnn>(axis);
143+
}
144+
return std::make_shared<ConcatLayer>(axis);
145+
}
137146
};
138147

139148
} // namespace it_lab_ai
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
#pragma once
2+
3+
#include <cstdint>
#include <dnnl.hpp>
#include <memory>
#include <unordered_map>
#include <vector>

#include "layers/Layer.hpp"
#include "layers/Tensor.hpp"
9+
10+
namespace it_lab_ai {
11+
12+
class ConcatLayerOneDnn : public Layer {
13+
public:
14+
explicit ConcatLayerOneDnn(int64_t axis = 0) : Layer(kConcat), axis_(axis) {}
15+
16+
ConcatLayerOneDnn(const ConcatLayerOneDnn& c)
17+
: Layer(kConcat), axis_(c.axis_) {}
18+
19+
void run(const std::vector<Tensor>& input,
20+
std::vector<Tensor>& output) override;
21+
22+
#ifdef ENABLE_STATISTIC_WEIGHTS
23+
Tensor get_weights() override {
24+
std::vector<int> v = {0};
25+
return make_tensor(v);
26+
}
27+
#endif
28+
29+
private:
30+
int64_t axis_;
31+
32+
bool initialized_ = false;
33+
Type last_type_;
34+
std::vector<Shape> last_shapes_;
35+
36+
std::unique_ptr<dnnl::engine> engine_;
37+
std::unique_ptr<dnnl::stream> stream_;
38+
std::unique_ptr<dnnl::concat> concat_prim_;
39+
40+
std::vector<dnnl::memory::desc> src_mds_;
41+
dnnl::memory::desc dst_md_;
42+
43+
Shape output_shape_;
44+
45+
std::vector<dnnl::memory> src_mems_;
46+
dnnl::memory dst_mem_;
47+
std::unordered_map<int, dnnl::memory> args_;
48+
49+
std::vector<float> dst_buffer_f32_;
50+
std::vector<int> dst_buffer_s32_;
51+
52+
void initialize_onednn(const std::vector<Tensor>& input);
53+
54+
static void validate_input(const std::vector<Tensor>& input);
55+
56+
[[nodiscard]] static dnnl::memory::data_type get_dnnl_data_type(Type type);
57+
58+
[[nodiscard]] static dnnl::memory::format_tag pick_format(size_t ndims);
59+
60+
[[nodiscard]] static std::vector<dnnl::memory::dim> shape_to_dims(
61+
const Shape& shape);
62+
63+
[[nodiscard]] static Shape calculate_output_shape(
64+
const std::vector<Tensor>& inputs, int64_t axis);
65+
66+
[[nodiscard]] static int64_t normalize_axis(int64_t axis, size_t rank);
67+
};
68+
69+
} // namespace it_lab_ai

src/layers_oneDNN/ConcatLayer.cpp

Lines changed: 222 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,222 @@
1+
#include "layers_oneDNN/ConcatLayer.hpp"
2+
3+
#include <stdexcept>
4+
5+
namespace it_lab_ai {
6+
7+
void ConcatLayerOneDnn::run(const std::vector<Tensor>& input,
8+
std::vector<Tensor>& output) {
9+
validate_input(input);
10+
11+
if (input.size() == 1) {
12+
output = input;
13+
return;
14+
}
15+
16+
Type type = input[0].get_type();
17+
18+
bool need_reinit = !initialized_ || last_type_ != type ||
19+
last_shapes_.size() != input.size();
20+
21+
if (!need_reinit) {
22+
for (size_t i = 0; i < input.size(); ++i) {
23+
if (last_shapes_[i] != input[i].get_shape()) {
24+
need_reinit = true;
25+
break;
26+
}
27+
}
28+
}
29+
30+
if (need_reinit) {
31+
initialize_onednn(input);
32+
}
33+
34+
output.resize(1);
35+
36+
if (type == Type::kFloat) {
37+
for (size_t i = 0; i < input.size(); ++i) {
38+
if (last_type_ == Type::kFloat) {
39+
src_mems_[i].set_data_handle(
40+
const_cast<float*>(input[i].as<float>()->data()));
41+
} else {
42+
src_mems_[i].set_data_handle(
43+
const_cast<int*>(input[i].as<int>()->data()));
44+
}
45+
46+
args_[DNNL_ARG_MULTIPLE_SRC + i] = src_mems_[i];
47+
}
48+
49+
args_[DNNL_ARG_DST] = dst_mem_;
50+
51+
concat_prim_->execute(*stream_, args_);
52+
stream_->wait();
53+
54+
output[0] = make_tensor(dst_buffer_f32_, output_shape_);
55+
} else if (type == Type::kInt) {
56+
for (size_t i = 0; i < input.size(); ++i) {
57+
src_mems_[i].set_data_handle(
58+
const_cast<int*>(input[i].as<int>()->data()));
59+
args_[DNNL_ARG_MULTIPLE_SRC + i] = src_mems_[i];
60+
}
61+
62+
args_[DNNL_ARG_DST] = dst_mem_;
63+
64+
concat_prim_->execute(*stream_, args_);
65+
stream_->wait();
66+
67+
output[0] = make_tensor(dst_buffer_s32_, output_shape_);
68+
}
69+
}
70+
71+
void ConcatLayerOneDnn::validate_input(const std::vector<Tensor>& input) {
72+
Type type = input[0].get_type();
73+
const Shape& base = input[0].get_shape();
74+
75+
for (size_t i = 1; i < input.size(); ++i) {
76+
if (input[i].get_type() != type) {
77+
throw std::runtime_error(
78+
"ConcatLayerOneDnn: All tensors must have same type");
79+
}
80+
81+
if (input[i].get_shape().dims() != base.dims()) {
82+
throw std::runtime_error(
83+
"ConcatLayerOneDnn: All tensors must have same rank");
84+
}
85+
}
86+
}
87+
88+
void ConcatLayerOneDnn::initialize_onednn(const std::vector<Tensor>& input) {
89+
if (!engine_) {
90+
engine_ = std::make_unique<dnnl::engine>(dnnl::engine::kind::cpu, 0);
91+
}
92+
if (!stream_) {
93+
stream_ = std::make_unique<dnnl::stream>(*engine_);
94+
}
95+
96+
size_t rank = input[0].get_shape().dims();
97+
int64_t axis = normalize_axis(axis_, rank);
98+
99+
last_type_ = input[0].get_type();
100+
auto type = get_dnnl_data_type(last_type_);
101+
102+
auto layout = pick_format(rank);
103+
104+
src_mds_.clear();
105+
for (const auto& t : input) {
106+
src_mds_.emplace_back(shape_to_dims(t.get_shape()), type, layout);
107+
}
108+
109+
output_shape_ = calculate_output_shape(input, axis);
110+
111+
dst_md_ = dnnl::memory::desc(shape_to_dims(output_shape_), type, layout);
112+
113+
auto concat_pd =
114+
dnnl::concat::primitive_desc(*engine_, dst_md_, axis, src_mds_);
115+
concat_prim_ = std::make_unique<dnnl::concat>(concat_pd);
116+
117+
dst_md_ = concat_pd.dst_desc();
118+
src_mds_.clear();
119+
for (size_t i = 0; i < input.size(); ++i) {
120+
src_mds_.push_back(concat_pd.src_desc(i));
121+
}
122+
123+
size_t n = input.size();
124+
src_mems_.resize(n);
125+
for (size_t i = 0; i < n; ++i) {
126+
src_mems_[i] = dnnl::memory(src_mds_[i], *engine_, nullptr);
127+
}
128+
129+
size_t out_size = output_shape_.count();
130+
if (last_type_ == Type::kFloat) {
131+
dst_buffer_f32_.resize(out_size);
132+
dst_mem_ = dnnl::memory(dst_md_, *engine_, dst_buffer_f32_.data());
133+
} else {
134+
dst_buffer_s32_.resize(out_size);
135+
dst_mem_ = dnnl::memory(dst_md_, *engine_, dst_buffer_s32_.data());
136+
}
137+
138+
args_.clear();
139+
for (size_t i = 0; i < n; ++i) {
140+
args_[DNNL_ARG_MULTIPLE_SRC + i] = src_mems_[i];
141+
}
142+
args_[DNNL_ARG_DST] = dst_mem_;
143+
144+
last_shapes_.clear();
145+
for (const auto& t : input) {
146+
last_shapes_.push_back(t.get_shape());
147+
}
148+
149+
initialized_ = true;
150+
}
151+
152+
/// Maps the project's element type to the matching oneDNN data type.
///
/// @param type Element type of the input tensors.
/// @return `f32` for `Type::kFloat`, `s32` for `Type::kInt`.
/// @throws std::runtime_error for any other type.
dnnl::memory::data_type ConcatLayerOneDnn::get_dnnl_data_type(Type type) {
  using DataType = dnnl::memory::data_type;

  if (type == Type::kFloat) {
    return DataType::f32;
  }
  if (type == Type::kInt) {
    return DataType::s32;
  }
  throw std::runtime_error("Unsupported data type for oneDNN");
}
162+
163+
/// Selects the plain (row-major) oneDNN format tag for a tensor rank.
///
/// Ranks 1..12 map to `a`, `ab`, ..., `abcdefghijkl`, so tensors with more
/// than five dimensions also get a concrete layout. Any other rank keeps the
/// previous fallback, `format_tag::any`.
///
/// @param ndims Tensor rank.
/// @return The plain format tag for `ndims`, or `format_tag::any`.
dnnl::memory::format_tag ConcatLayerOneDnn::pick_format(size_t ndims) {
  using Tag = dnnl::memory::format_tag;

  // Index = rank - 1.
  static constexpr Tag kPlainTags[] = {
      Tag::a,          Tag::ab,          Tag::abc,
      Tag::abcd,       Tag::abcde,       Tag::abcdef,
      Tag::abcdefg,    Tag::abcdefgh,    Tag::abcdefghi,
      Tag::abcdefghij, Tag::abcdefghijk, Tag::abcdefghijkl};
  constexpr size_t kMaxRank = sizeof(kPlainTags) / sizeof(kPlainTags[0]);

  if (ndims == 0 || ndims > kMaxRank) {
    return Tag::any;
  }
  return kPlainTags[ndims - 1];
}
179+
180+
std::vector<dnnl::memory::dim> ConcatLayerOneDnn::shape_to_dims(
181+
const Shape& shape) {
182+
std::vector<dnnl::memory::dim> dims;
183+
184+
for (size_t i = 0; i < shape.dims(); ++i) {
185+
dims.push_back(static_cast<dnnl::memory::dim>(shape.at(i)));
186+
}
187+
188+
return dims;
189+
}
190+
191+
/// Computes the shape of the concatenated tensor.
///
/// The result equals the first input's shape, except along `axis`, where the
/// extents of all inputs are summed.
///
/// @param inputs Tensors to concatenate; all are indexed up to the rank of
///               `inputs[0]`.
/// @param axis   Concatenation axis, used directly as an index into the
///               shape (callers pass the result of `normalize_axis()`).
/// @return Shape of the concatenation result.
/// @throws std::runtime_error if an input differs from the first one on any
///         axis other than `axis`.
Shape ConcatLayerOneDnn::calculate_output_shape(
    const std::vector<Tensor>& inputs, int64_t axis) {
  const Shape& base = inputs[0].get_shape();
  const size_t rank = base.dims();
  const auto concat_axis = static_cast<size_t>(axis);

  std::vector<size_t> dims(rank);
  for (size_t i = 0; i < rank; ++i) {
    dims[i] = base[i];
  }
  dims[concat_axis] = 0;

  for (const auto& t : inputs) {
    const Shape& shape = t.get_shape();

    // Concatenation is only defined when every other axis agrees; report a
    // mismatch here instead of leaving it to surface later.
    for (size_t i = 0; i < rank; ++i) {
      if (i != concat_axis && shape[i] != base[i]) {
        throw std::runtime_error(
            "ConcatLayerOneDnn: All tensors must match on non-concat "
            "dimensions");
      }
    }

    dims[concat_axis] += shape[concat_axis];
  }

  return Shape(dims);
}
209+
210+
/// Resolves a possibly negative axis against a tensor rank.
///
/// @param axis Requested axis; a negative value has `rank` added to it.
/// @param rank Tensor rank.
/// @return The resolved axis in `[0, rank)`.
/// @throws std::runtime_error if the resolved axis is outside `[0, rank)`.
int64_t ConcatLayerOneDnn::normalize_axis(int64_t axis, size_t rank) {
  const auto signed_rank = static_cast<int64_t>(rank);
  const int64_t resolved = (axis < 0) ? axis + signed_rank : axis;

  const bool in_range = (resolved >= 0) && (resolved < signed_rank);
  if (!in_range) {
    throw std::runtime_error("ConcatLayerOneDnn: axis out of range");
  }
  return resolved;
}
221+
222+
} // namespace it_lab_ai

0 commit comments

Comments
 (0)