Skip to content

Commit 1966fe8

Browse files
authored
BinaryLayer && OneDNN (#267)
1 parent 5e23309 commit 1966fe8

5 files changed

Lines changed: 750 additions & 1 deletion

File tree

app/Graph/build.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -617,7 +617,7 @@ ParseResult parse_json_model(RuntimeOptions options,
617617
continue;
618618
}
619619

620-
auto bin_layer = std::make_shared<it_lab_ai::BinaryOpLayer>(op);
620+
auto bin_layer = LayerFactory::createBinaryLayer(op, options);
621621
layer = bin_layer;
622622
}
623623
} else if (layer_type == "Gemm") {

app/Graph/build.hpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
#include "layers/SplitLayer.hpp"
3333
#include "layers/Tensor.hpp"
3434
#include "layers/TransposeLayer.hpp"
35+
#include "layers_oneDNN/BinaryOpLayer.hpp"
3536
#include "layers_oneDNN/ConvLayer.hpp"
3637
#include "layers_oneDNN/EWLayer.hpp"
3738
#include "layers_oneDNN/PoolingLayer.hpp"
@@ -102,6 +103,15 @@ class LayerFactory {
102103
bias, group, useLegacyImpl);
103104
}
104105

106+
static std::shared_ptr<Layer> createBinaryLayer(
107+
const it_lab_ai::BinaryOpLayer::Operation op,
108+
const RuntimeOptions& options) {
109+
if (options.backend == Backend::kOneDnn) {
110+
return std::make_shared<it_lab_ai::BinaryOpLayerOneDnn>(op);
111+
}
112+
return std::make_shared<it_lab_ai::BinaryOpLayer>(op);
113+
}
114+
105115
static std::shared_ptr<Layer> createReduceLayer(
106116
ReduceLayer::Operation op, int64_t keepdims,
107117
const std::vector<int64_t>& axes, const RuntimeOptions& options) {
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
#pragma once
2+
#include <dnnl.hpp>
3+
#include <memory>
4+
#include <string>
5+
#include <vector>
6+
7+
#include "layers/BinaryOpLayer.hpp"
8+
#include "layers/Layer.hpp"
9+
#include "layers/Tensor.hpp"
10+
11+
namespace it_lab_ai {
12+
13+
class BinaryOpLayerOneDnn : public Layer {
14+
public:
15+
BinaryOpLayerOneDnn()
16+
: Layer(kBinaryOp), op_(BinaryOpLayer::Operation::kMul) {}
17+
explicit BinaryOpLayerOneDnn(BinaryOpLayer::Operation op)
18+
: Layer(kBinaryOp), op_(op) {}
19+
20+
void run(const std::vector<Tensor>& input,
21+
std::vector<Tensor>& output) override;
22+
23+
void set_operation(BinaryOpLayer::Operation op) {
24+
op_ = op;
25+
initialized_ = false;
26+
}
27+
28+
#ifdef ENABLE_STATISTIC_WEIGHTS
29+
Tensor get_weights() override {
30+
std::vector<int> v = {0};
31+
Tensor a = make_tensor(v);
32+
return a;
33+
}
34+
#endif
35+
36+
private:
37+
BinaryOpLayer::Operation op_;
38+
bool initialized_ = false;
39+
Shape last_shape_a_;
40+
Shape last_shape_b_;
41+
Type last_type_;
42+
43+
std::unique_ptr<dnnl::engine> engine_;
44+
std::unique_ptr<dnnl::stream> stream_;
45+
std::unique_ptr<dnnl::binary> binary_prim_;
46+
dnnl::memory::desc src0_md_;
47+
dnnl::memory::desc src1_md_;
48+
dnnl::memory::desc dst_md_;
49+
Shape output_shape_;
50+
51+
void initialize_onednn(const Tensor& A, const Tensor& B);
52+
static void validate_input(const std::vector<Tensor>& input);
53+
[[nodiscard]] static dnnl::memory::data_type get_dnnl_data_type(Type type);
54+
[[nodiscard]] static dnnl::algorithm get_dnnl_algorithm(
55+
BinaryOpLayer::Operation op);
56+
[[nodiscard]] static Shape calculate_output_shape(const Shape& shape_a,
57+
const Shape& shape_b);
58+
[[nodiscard]] static bool can_broadcast(const Shape& shape_a,
59+
const Shape& shape_b);
60+
[[nodiscard]] static dnnl::memory::format_tag pick_format(size_t ndims);
61+
[[nodiscard]] static std::vector<dnnl::memory::dim> shape_to_dims(
62+
const Shape& shape);
63+
};
64+
65+
} // namespace it_lab_ai
Lines changed: 223 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,223 @@
1+
#include "layers_oneDNN/BinaryOpLayer.hpp"
2+
3+
#include <algorithm>
4+
#include <iostream>
5+
#include <stdexcept>
6+
7+
namespace it_lab_ai {
8+
9+
void BinaryOpLayerOneDnn::run(const std::vector<Tensor>& input,
10+
std::vector<Tensor>& output) {
11+
validate_input(input);
12+
13+
const Tensor& a = input[0];
14+
const Tensor& b = input[1];
15+
Type type = a.get_type();
16+
17+
bool need_reinit = !initialized_ || last_type_ != type ||
18+
last_shape_a_ != a.get_shape() ||
19+
last_shape_b_ != b.get_shape();
20+
21+
if (need_reinit) {
22+
initialize_onednn(a, b);
23+
}
24+
25+
output.resize(1);
26+
output_shape_ = calculate_output_shape(a.get_shape(), b.get_shape());
27+
28+
if (type == Type::kFloat) {
29+
const auto& src0_data = *a.as<float>();
30+
const auto& src1_data = *b.as<float>();
31+
std::vector<float> dst_data(output_shape_.count());
32+
33+
dnnl::memory src0_mem(src0_md_, *engine_,
34+
const_cast<float*>(src0_data.data()));
35+
dnnl::memory src1_mem(src1_md_, *engine_,
36+
const_cast<float*>(src1_data.data()));
37+
dnnl::memory dst_mem(dst_md_, *engine_, dst_data.data());
38+
39+
binary_prim_->execute(*stream_, {{DNNL_ARG_SRC_0, src0_mem},
40+
{DNNL_ARG_SRC_1, src1_mem},
41+
{DNNL_ARG_DST, dst_mem}});
42+
43+
stream_->wait();
44+
output[0] = make_tensor(dst_data, output_shape_);
45+
} else if (type == Type::kInt) {
46+
const auto& src0_data = *a.as<int>();
47+
const auto& src1_data = *b.as<int>();
48+
std::vector<int> dst_data(output_shape_.count());
49+
50+
dnnl::memory src0_mem(src0_md_, *engine_,
51+
const_cast<int*>(src0_data.data()));
52+
dnnl::memory src1_mem(src1_md_, *engine_,
53+
const_cast<int*>(src1_data.data()));
54+
dnnl::memory dst_mem(dst_md_, *engine_, dst_data.data());
55+
56+
binary_prim_->execute(*stream_, {{DNNL_ARG_SRC_0, src0_mem},
57+
{DNNL_ARG_SRC_1, src1_mem},
58+
{DNNL_ARG_DST, dst_mem}});
59+
60+
stream_->wait();
61+
output[0] = make_tensor(dst_data, output_shape_);
62+
}
63+
}
64+
65+
void BinaryOpLayerOneDnn::validate_input(const std::vector<Tensor>& input) {
66+
if (input.size() != 2) {
67+
throw std::runtime_error(
68+
"BinaryOpLayerOneDnn: Expected exactly 2 input tensors");
69+
}
70+
71+
if (input[0].get_type() != input[1].get_type()) {
72+
throw std::runtime_error(
73+
"BinaryOpLayerOneDnn: Input tensors must have the same type");
74+
}
75+
76+
const Shape& shape_a = input[0].get_shape();
77+
const Shape& shape_b = input[1].get_shape();
78+
79+
if (!can_broadcast(shape_a, shape_b)) {
80+
throw std::runtime_error(
81+
"BinaryOpLayerOneDnn: Incompatible shapes for broadcasting");
82+
}
83+
}
84+
85+
// Computes the broadcast result shape of `shape_a` and `shape_b`.
//
// The shapes are aligned on their trailing dimension; a dimension missing in
// the shorter shape counts as 1. Each result extent is the larger of the two
// aligned extents.
// Throws std::runtime_error when two aligned extents differ and neither is 1.
Shape BinaryOpLayerOneDnn::calculate_output_shape(const Shape& shape_a,
                                                  const Shape& shape_b) {
  const size_t rank_a = shape_a.dims();
  const size_t rank_b = shape_b.dims();
  const size_t rank_out = std::max(rank_a, rank_b);
  Shape broadcast(rank_out);

  // `offset` counts dimensions from the back: 1 is the last dimension.
  for (size_t offset = 1; offset <= rank_out; ++offset) {
    const size_t extent_a = (offset <= rank_a) ? shape_a[rank_a - offset] : 1;
    const size_t extent_b = (offset <= rank_b) ? shape_b[rank_b - offset] : 1;

    const bool compatible =
        (extent_a == extent_b) || (extent_a == 1) || (extent_b == 1);
    if (!compatible) {
      throw std::runtime_error("BinaryOpLayerOneDnn: Incompatible dimensions");
    }
    broadcast[rank_out - offset] = std::max(extent_a, extent_b);
  }

  return broadcast;
}
108+
109+
bool BinaryOpLayerOneDnn::can_broadcast(const Shape& shape_a,
110+
const Shape& shape_b) {
111+
size_t dims_a = shape_a.dims();
112+
size_t dims_b = shape_b.dims();
113+
size_t max_dims = std::max(dims_a, dims_b);
114+
115+
for (size_t i = 0; i < max_dims; ++i) {
116+
size_t idx_a = dims_a - i - 1;
117+
size_t idx_b = dims_b - i - 1;
118+
119+
size_t dim_a = (i < dims_a) ? shape_a[idx_a] : 1;
120+
size_t dim_b = (i < dims_b) ? shape_b[idx_b] : 1;
121+
122+
if (dim_a != dim_b && dim_a != 1 && dim_b != 1) {
123+
return false;
124+
}
125+
}
126+
127+
return true;
128+
}
129+
130+
void BinaryOpLayerOneDnn::initialize_onednn(const Tensor& A, const Tensor& B) {
131+
engine_ = std::make_unique<dnnl::engine>(dnnl::engine::kind::cpu, 0);
132+
stream_ = std::make_unique<dnnl::stream>(*engine_);
133+
134+
const Shape& shape_a = A.get_shape();
135+
const Shape& shape_b = B.get_shape();
136+
output_shape_ = calculate_output_shape(shape_a, shape_b);
137+
138+
auto dnnl_type = get_dnnl_data_type(A.get_type());
139+
140+
auto dims_a = shape_to_dims(shape_a);
141+
auto dims_b = shape_to_dims(shape_b);
142+
auto dims_output = shape_to_dims(output_shape_);
143+
144+
size_t ndims = output_shape_.dims();
145+
auto format = pick_format(ndims);
146+
147+
src0_md_ = dnnl::memory::desc(dims_a, dnnl_type, format);
148+
src1_md_ = dnnl::memory::desc(dims_b, dnnl_type, format);
149+
dst_md_ = dnnl::memory::desc(dims_output, dnnl_type, format);
150+
151+
try {
152+
auto binary_pd = dnnl::binary::primitive_desc(
153+
*engine_, get_dnnl_algorithm(op_), src0_md_, src1_md_, dst_md_);
154+
155+
binary_prim_ = std::make_unique<dnnl::binary>(binary_pd);
156+
} catch (const dnnl::error& e) {
157+
std::cerr << "Error creating binary primitive: " << e.what() << '\n';
158+
throw std::runtime_error("Failed to create binary primitive: " +
159+
std::string(e.what()));
160+
}
161+
162+
last_shape_a_ = shape_a;
163+
last_shape_b_ = shape_b;
164+
last_type_ = A.get_type();
165+
initialized_ = true;
166+
}
167+
168+
// Maps the tensor element type to the matching oneDNN data type:
// Type::kFloat -> f32, Type::kInt -> s32.
// Throws std::runtime_error for any other type.
dnnl::memory::data_type BinaryOpLayerOneDnn::get_dnnl_data_type(Type type) {
  using data_type = dnnl::memory::data_type;

  if (type == Type::kFloat) {
    return data_type::f32;
  }
  if (type == Type::kInt) {
    return data_type::s32;
  }
  throw std::runtime_error("Unsupported data type for oneDNN");
}
178+
179+
// Maps the layer operation to the matching oneDNN binary algorithm:
// kAdd -> binary_add, kMul -> binary_mul.
// Throws std::invalid_argument for any other operation.
dnnl::algorithm BinaryOpLayerOneDnn::get_dnnl_algorithm(
    BinaryOpLayer::Operation op) {
  using Operation = BinaryOpLayer::Operation;

  if (op == Operation::kAdd) {
    return dnnl::algorithm::binary_add;
  }
  if (op == Operation::kMul) {
    return dnnl::algorithm::binary_mul;
  }
  throw std::invalid_argument("Unsupported binary operation for oneDNN");
}
190+
191+
// Returns the plain (dense, row-major) oneDNN format tag for a tensor with
// `ndims` dimensions. Rank 0 is treated like rank 1.
//
// A concrete tag is always returned: the descriptors built from it wrap
// caller-owned buffers, which a `format_tag::any` placeholder cannot
// describe.
// Throws std::invalid_argument when `ndims` exceeds 12, the highest rank for
// which a plain tag is listed here.
dnnl::memory::format_tag BinaryOpLayerOneDnn::pick_format(size_t ndims) {
  using tag = dnnl::memory::format_tag;

  switch (ndims) {
    case 0:
    case 1:
      return tag::a;
    case 2:
      return tag::ab;
    case 3:
      return tag::abc;
    case 4:
      return tag::abcd;
    case 5:
      return tag::abcde;
    case 6:
      return tag::abcdef;
    case 7:
      return tag::abcdefg;
    case 8:
      return tag::abcdefgh;
    case 9:
      return tag::abcdefghi;
    case 10:
      return tag::abcdefghij;
    case 11:
      return tag::abcdefghijk;
    case 12:
      return tag::abcdefghijkl;
    default:
      throw std::invalid_argument(
          "BinaryOpLayerOneDnn: tensors with more than 12 dimensions are not "
          "supported by oneDNN");
  }
}
208+
209+
std::vector<dnnl::memory::dim> BinaryOpLayerOneDnn::shape_to_dims(
210+
const Shape& shape) {
211+
std::vector<dnnl::memory::dim> dims;
212+
for (size_t i = 0; i < shape.dims(); ++i) {
213+
dims.push_back(static_cast<dnnl::memory::dim>(shape.at(i)));
214+
}
215+
216+
if (dims.empty()) {
217+
dims.push_back(1);
218+
}
219+
220+
return dims;
221+
}
222+
223+
} // namespace it_lab_ai

0 commit comments

Comments
 (0)