1+ #pragma once
2+ #include < algorithm>
3+ #include < memory>
4+ #include < stdexcept>
5+ #include < utility>
6+ #include < vector>
7+
8+ #include " Tensor.hpp"
9+ #include " layers/Layer.hpp"
10+
11+ namespace it_lab_ai {
12+
13+ class BinaryOpLayer : public Layer {
14+ public:
15+ enum class Operation : uint8_t { kMul , kAdd , kSub , kDiv };
16+
17+ BinaryOpLayer () = default ;
18+ explicit BinaryOpLayer (Operation op) : op_(op) {}
19+
20+ static std::string get_name () { return " Binary Operation Layer" ; }
21+ void run (const Tensor& input, Tensor& output) override ;
22+ void run (const Tensor& A, const Tensor& B, Tensor& output);
23+ static bool is_scalar_tensor (const Tensor& t);
24+
25+ #ifdef ENABLE_STATISTIC_WEIGHTS
26+ Tensor get_weights () override {
27+ std::vector<int > v = {0 };
28+ return make_tensor (v);
29+ }
30+ #endif
31+
32+ private:
33+ Operation op_ = Operation::kMul ;
34+
35+ template <typename ValueType>
36+ void run_with_scalar_impl (const Tensor& input, ValueType scalar,
37+ Tensor& output) const ;
38+ template <typename ValueType>
39+ void run_broadcast_impl (const Tensor& A, const Tensor& B, Tensor& output,
40+ const Shape& output_shape) const ;
41+ void run_with_scalar (const Tensor& input, float scalar, Tensor& output) const ;
42+
43+ static bool can_broadcast (const Shape& shape_A, const Shape& shape_B);
44+ static Shape calculate_broadcasted_shape (const Shape& shape_A,
45+ const Shape& shape_B);
46+ static std::vector<size_t > get_strides (const Shape& shape);
47+ static size_t get_broadcasted_index (
48+ size_t flat_index, const Shape& input_shape, const Shape& output_shape,
49+ const std::vector<size_t >& input_strides,
50+ const std::vector<size_t >& output_strides);
51+
52+ template <typename ValueType>
53+ class BinaryOpLayerImpl ;
54+ };
55+
56+ } // namespace it_lab_ai
0 commit comments