1- #include " layers_oneDNN/EWLayer_oneDNN .hpp"
1+ #include " layers_oneDNN/EwLayer_oneDnn .hpp"
22
33#include < iostream>
44#include < stdexcept>
55
66namespace it_lab_ai {
77
8- void EWLayer_oneDNN ::run (const std::vector<Tensor>& input,
9- std::vector<Tensor>& output) {
8+ void EwLayerOneDnn ::run (const std::vector<Tensor>& input,
9+ std::vector<Tensor>& output) {
1010 validate_input (input);
1111
1212 const Tensor& input_tensor = input[0 ];
13+ Type data_type = input_tensor.get_type ();
1314
1415 if (!initialized_) {
15- initialize_onednn (input_tensor.get_shape ());
16- }
17- if (input_tensor.get_type () != Type::kFloat ) {
18- throw std::runtime_error (" oneDNN EWLayer supports only float tensors" );
16+ initialize_onednn (input_tensor.get_shape (), data_type);
1917 }
2018
2119 try {
22- const std::vector<float >& input_data = *input_tensor.as <float >();
23- std::vector<float > output_data (input_data.size ());
24- dnnl::memory src_mem = dnnl::memory (memory_desc_, *engine_,
25- const_cast <float *>(input_data.data ()));
26- dnnl::memory dst_mem =
27- dnnl::memory (memory_desc_, *engine_, output_data.data ());
28- eltwise_prim_->execute (*stream_,
29- {{DNNL_ARG_SRC , src_mem}, {DNNL_ARG_DST , dst_mem}});
30- stream_->wait ();
31- output[0 ] = make_tensor (output_data, input_tensor.get_shape ());
20+ if (data_type == Type::kFloat ) {
21+ const std::vector<float >& input_data = *input_tensor.as <float >();
22+ std::vector<float > output_data (input_data.size ());
23+ dnnl::memory src_mem = dnnl::memory (
24+ memory_desc_, *engine_, const_cast <float *>(input_data.data ()));
25+ dnnl::memory dst_mem =
26+ dnnl::memory (memory_desc_, *engine_, output_data.data ());
27+ eltwise_prim_->execute (
28+ *stream_, {{DNNL_ARG_SRC , src_mem}, {DNNL_ARG_DST , dst_mem}});
29+ stream_->wait ();
30+ output[0 ] = make_tensor (output_data, input_tensor.get_shape ());
31+ } else if (data_type == Type::kInt ) {
32+ const std::vector<int >& input_data = *input_tensor.as <int >();
33+ std::vector<int > output_data (input_data.size ());
34+
35+ std::vector<float > float_input;
36+ float_input.reserve (input_data.size ());
37+ for (int val : input_data) {
38+ float_input.push_back (static_cast <float >(val));
39+ }
40+
41+ std::vector<float > float_output (input_data.size ());
42+
43+ dnnl::memory src_mem =
44+ dnnl::memory (memory_desc_, *engine_, float_input.data ());
45+ dnnl::memory dst_mem =
46+ dnnl::memory (memory_desc_, *engine_, float_output.data ());
47+ eltwise_prim_->execute (
48+ *stream_, {{DNNL_ARG_SRC , src_mem}, {DNNL_ARG_DST , dst_mem}});
49+ stream_->wait ();
50+
51+ for (size_t i = 0 ; i < float_output.size (); ++i) {
52+ output_data[i] = static_cast <int >(std::round (float_output[i]));
53+ }
54+ output[0 ] = make_tensor (output_data, input_tensor.get_shape ());
55+ } else {
56+ throw std::runtime_error (" EwLayerOneDnn: Unsupported data type" );
57+ }
3258
3359 } catch (const std::exception& e) {
3460 std::cerr << " oneDNN execution failed: " << e.what () << std::endl;
3561 throw ;
3662 }
3763}
3864
39- void EWLayer_oneDNN ::validate_input (const std::vector<Tensor>& input) const {
65+ void EwLayerOneDnn ::validate_input (const std::vector<Tensor>& input) const {
4066 if (input.size () != 1 ) {
41- throw std::runtime_error (" EWLayer_oneDNN : Expected exactly 1 input tensor" );
67+ throw std::runtime_error (" EwLayerOneDnn : Expected exactly 1 input tensor" );
4268 }
4369
4470 if (!is_function_supported (func_)) {
4571 throw std::invalid_argument (" Unsupported function for oneDNN: " + func_);
4672 }
73+
74+ Type data_type = input[0 ].get_type ();
75+ if (data_type != Type::kFloat && data_type != Type::kInt ) {
76+ throw std::runtime_error (
77+ " EwLayerOneDnn supports only float and int tensors" );
78+ }
4779}
4880
49- void EWLayer_oneDNN ::initialize_onednn (const Shape& shape) {
81+ void EwLayerOneDnn ::initialize_onednn (const Shape& shape, Type data_type ) {
5082 try {
5183 engine_ = std::make_unique<dnnl::engine>(dnnl::engine::kind::cpu, 0 );
5284 stream_ = std::make_unique<dnnl::stream>(*engine_);
@@ -55,6 +87,7 @@ void EWLayer_oneDNN::initialize_onednn(const Shape& shape) {
5587 for (size_t i = 0 ; i < shape.dims (); i++) {
5688 dims.push_back (static_cast <dnnl::memory::dim>(shape.at (i)));
5789 }
90+
5891 dnnl::memory::format_tag format;
5992 switch (dims.size ()) {
6093 case 1 :
@@ -77,16 +110,22 @@ void EWLayer_oneDNN::initialize_onednn(const Shape& shape) {
77110 std::to_string (dims.size ()));
78111 }
79112
80- memory_desc_ =
81- dnnl::memory::desc (dims, dnnl::memory::data_type::f32 , format);
113+ dnnl::memory::data_type dnnl_data_type;
114+ if (data_type == Type::kFloat ) {
115+ dnnl_data_type = dnnl::memory::data_type::f32 ;
116+ } else {
117+ dnnl_data_type = dnnl::memory::data_type::f32 ;
118+ }
119+
120+ memory_desc_ = dnnl::memory::desc (dims, dnnl_data_type, format);
82121
83122 dnnl::algorithm algo = get_algorithm ();
84123
85- float primitive_alpha = 0 .0f ;
86- float primitive_beta = 0 .0f ;
124+ float primitive_alpha = 0 .0F ;
125+ float primitive_beta = 0 .0F ;
87126
88127 if (func_ == " relu" ) {
89- primitive_alpha = 0 .0f ;
128+ primitive_alpha = 0 .0F ;
90129 } else if (func_ == " linear" ) {
91130 primitive_alpha = alpha_;
92131 primitive_beta = beta_;
@@ -100,34 +139,31 @@ void EWLayer_oneDNN::initialize_onednn(const Shape& shape) {
100139
101140 initialized_ = true ;
102141
103- for (size_t i = 0 ; i < dims.size (); ++i) {
104- std::cout << dims[i];
105- if (i < dims.size () - 1 ) std::cout << " , " ;
106- }
107- std::cout << " ]" << std::endl;
108-
109142 } catch (const std::exception& e) {
110143 std::cerr << " oneDNN initialization failed for function '" << func_
111144 << " ': " << e.what () << std::endl;
112145 throw ;
113146 }
114147}
115148
116- dnnl::algorithm EWLayer_oneDNN ::get_algorithm () const {
149+ dnnl::algorithm EwLayerOneDnn ::get_algorithm () const {
117150 if (func_ == " relu" ) {
118151 return dnnl::algorithm::eltwise_relu;
119- } else if (func_ == " tanh" ) {
152+ }
153+ if (func_ == " tanh" ) {
120154 return dnnl::algorithm::eltwise_tanh;
121- } else if (func_ == " sigmoid" ) {
155+ }
156+ if (func_ == " sigmoid" ) {
122157 return dnnl::algorithm::eltwise_logistic;
123- } else if (func_ == " linear" ) {
158+ }
159+ if (func_ == " linear" ) {
124160 return dnnl::algorithm::eltwise_linear;
125- } else {
126- throw std::invalid_argument (" Unsupported function for oneDNN: " + func_);
127161 }
162+
163+ throw std::invalid_argument (" Unsupported function for oneDNN: " + func_);
128164}
129165
130- bool EWLayer_oneDNN ::is_function_supported (const std::string& function) {
166+ bool EwLayerOneDnn ::is_function_supported (const std::string& function) {
131167 return (function == " relu" || function == " tanh" || function == " sigmoid" ||
132168 function == " linear" );
133169}
0 commit comments