1+ #include " layers_oneDNN/ConcatLayer.hpp"
2+
3+ #include < stdexcept>
4+
5+ namespace it_lab_ai {
6+
7+ void ConcatLayerOneDnn::run (const std::vector<Tensor>& input,
8+ std::vector<Tensor>& output) {
9+ validate_input (input);
10+
11+ if (input.size () == 1 ) {
12+ output = input;
13+ return ;
14+ }
15+
16+ Type type = input[0 ].get_type ();
17+
18+ bool need_reinit = !initialized_ || last_type_ != type ||
19+ last_shapes_.size () != input.size ();
20+
21+ if (!need_reinit) {
22+ for (size_t i = 0 ; i < input.size (); ++i) {
23+ if (last_shapes_[i] != input[i].get_shape ()) {
24+ need_reinit = true ;
25+ break ;
26+ }
27+ }
28+ }
29+
30+ if (need_reinit) {
31+ initialize_onednn (input);
32+ }
33+
34+ output.resize (1 );
35+
36+ if (type == Type::kFloat ) {
37+ for (size_t i = 0 ; i < input.size (); ++i) {
38+ if (last_type_ == Type::kFloat )
39+ src_mems_[i].set_data_handle (
40+ const_cast <float *>(input[i].as <float >()->data ()));
41+ else
42+ src_mems_[i].set_data_handle (
43+ const_cast <int *>(input[i].as <int >()->data ()));
44+
45+ args_[DNNL_ARG_MULTIPLE_SRC + i] = src_mems_[i];
46+ }
47+
48+ args_[DNNL_ARG_DST ] = dst_mem_;
49+
50+ concat_prim_->execute (*stream_, args_);
51+ stream_->wait ();
52+
53+ output[0 ] = make_tensor (dst_buffer_f32_, output_shape_);
54+ } else if (type == Type::kInt ) {
55+ for (size_t i = 0 ; i < input.size (); ++i) {
56+ src_mems_[i].set_data_handle (
57+ const_cast <int *>(input[i].as <int >()->data ()));
58+ args_[DNNL_ARG_MULTIPLE_SRC + i] = src_mems_[i];
59+ }
60+
61+ args_[DNNL_ARG_DST ] = dst_mem_;
62+
63+ concat_prim_->execute (*stream_, args_);
64+ stream_->wait ();
65+
66+ output[0 ] = make_tensor (dst_buffer_s32_, output_shape_);
67+ }
68+ }
69+
70+ void ConcatLayerOneDnn::validate_input (const std::vector<Tensor>& input) {
71+ Type type = input[0 ].get_type ();
72+ const Shape& base = input[0 ].get_shape ();
73+
74+ for (size_t i = 1 ; i < input.size (); ++i) {
75+ if (input[i].get_type () != type) {
76+ throw std::runtime_error (
77+ " ConcatLayerOneDnn: All tensors must have same type" );
78+ }
79+
80+ if (input[i].get_shape ().dims () != base.dims ()) {
81+ throw std::runtime_error (
82+ " ConcatLayerOneDnn: All tensors must have same rank" );
83+ }
84+ }
85+ }
86+
87+ void ConcatLayerOneDnn::initialize_onednn (const std::vector<Tensor>& input) {
88+ if (!engine_)
89+ engine_ = std::make_unique<dnnl::engine>(dnnl::engine::kind::cpu, 0 );
90+ if (!stream_) stream_ = std::make_unique<dnnl::stream>(*engine_);
91+
92+ size_t rank = input[0 ].get_shape ().dims ();
93+ int64_t axis = normalize_axis (axis_, rank);
94+
95+ last_type_ = input[0 ].get_type ();
96+ auto type = get_dnnl_data_type (last_type_);
97+
98+ auto layout = pick_format (rank);
99+
100+ src_mds_.clear ();
101+ for (const auto & t : input) {
102+ src_mds_.emplace_back (shape_to_dims (t.get_shape ()), type, layout);
103+ }
104+
105+ output_shape_ = calculate_output_shape (input, axis);
106+
107+ dst_md_ = dnnl::memory::desc (shape_to_dims (output_shape_), type, layout);
108+
109+ auto concat_pd =
110+ dnnl::concat::primitive_desc (*engine_, dst_md_, axis, src_mds_);
111+ concat_prim_ = std::make_unique<dnnl::concat>(concat_pd);
112+
113+ dst_md_ = concat_pd.dst_desc ();
114+ src_mds_.clear ();
115+ for (size_t i = 0 ; i < input.size (); ++i) {
116+ src_mds_.push_back (concat_pd.src_desc (i));
117+ }
118+
119+ size_t n = input.size ();
120+ src_mems_.resize (n);
121+ for (size_t i = 0 ; i < n; ++i) {
122+ src_mems_[i] = dnnl::memory (src_mds_[i], *engine_, nullptr );
123+ }
124+
125+ size_t out_size = output_shape_.count ();
126+ if (last_type_ == Type::kFloat ) {
127+ dst_buffer_f32_.resize (out_size);
128+ dst_mem_ = dnnl::memory (dst_md_, *engine_, dst_buffer_f32_.data ());
129+ } else {
130+ dst_buffer_s32_.resize (out_size);
131+ dst_mem_ = dnnl::memory (dst_md_, *engine_, dst_buffer_s32_.data ());
132+ }
133+
134+ args_.clear ();
135+ for (size_t i = 0 ; i < n; ++i)
136+ args_[DNNL_ARG_MULTIPLE_SRC + i] = src_mems_[i];
137+ args_[DNNL_ARG_DST ] = dst_mem_;
138+
139+ last_shapes_.clear ();
140+ for (const auto & t : input) last_shapes_.push_back (t.get_shape ());
141+
142+ initialized_ = true ;
143+ }
144+
145+ dnnl::memory::data_type ConcatLayerOneDnn::get_dnnl_data_type (Type type) {
146+ switch (type) {
147+ case Type::kFloat :
148+ return dnnl::memory::data_type::f32 ;
149+ case Type::kInt :
150+ return dnnl::memory::data_type::s32;
151+ default :
152+ throw std::runtime_error (" Unsupported data type for oneDNN" );
153+ }
154+ }
155+
156+ dnnl::memory::format_tag ConcatLayerOneDnn::pick_format (size_t ndims) {
157+ switch (ndims) {
158+ case 1 :
159+ return dnnl::memory::format_tag::a;
160+ case 2 :
161+ return dnnl::memory::format_tag::ab;
162+ case 3 :
163+ return dnnl::memory::format_tag::abc;
164+ case 4 :
165+ return dnnl::memory::format_tag::abcd;
166+ case 5 :
167+ return dnnl::memory::format_tag::abcde;
168+ default :
169+ return dnnl::memory::format_tag::any;
170+ }
171+ }
172+
173+ std::vector<dnnl::memory::dim> ConcatLayerOneDnn::shape_to_dims (
174+ const Shape& shape) {
175+ std::vector<dnnl::memory::dim> dims;
176+
177+ for (size_t i = 0 ; i < shape.dims (); ++i) {
178+ dims.push_back (static_cast <dnnl::memory::dim>(shape.at (i)));
179+ }
180+
181+ return dims;
182+ }
183+
184+ Shape ConcatLayerOneDnn::calculate_output_shape (
185+ const std::vector<Tensor>& inputs, int64_t axis) {
186+ const Shape& base = inputs[0 ].get_shape ();
187+
188+ std::vector<size_t > dims (base.dims ());
189+
190+ for (size_t i = 0 ; i < base.dims (); ++i) {
191+ dims[i] = base[i];
192+ }
193+
194+ dims[axis] = 0 ;
195+
196+ for (const auto & t : inputs) {
197+ dims[axis] += t.get_shape ()[axis];
198+ }
199+
200+ return Shape (dims);
201+ }
202+
203+ int64_t ConcatLayerOneDnn::normalize_axis (int64_t axis, size_t rank) {
204+ if (axis < 0 ) axis += rank;
205+
206+ if (axis < 0 || axis >= static_cast <int64_t >(rank)) {
207+ throw std::runtime_error (" ConcatLayerOneDnn: axis out of range" );
208+ }
209+
210+ return axis;
211+ }
212+
213+ } // namespace it_lab_ai
0 commit comments