1+ #pragma once
2+ #include < cstdint>
3+ #include < dnnl.hpp>
4+ #include < memory>
5+ #include < vector>
6+
7+ #include " layers/Layer.hpp"
8+ #include " layers/ReduceLayer.hpp"
9+ #include " layers/Tensor.hpp"
10+
11+ namespace it_lab_ai {
12+
13+ class ReduceLayerOneDnn : public Layer {
14+ public:
15+ ReduceLayerOneDnn (ReduceLayer::Operation op, int64_t keepdims,
16+ const std::vector<int64_t >& axes)
17+ : Layer(kReduce ), op_(op), keepdims_(keepdims), axes_(axes) {}
18+
19+ explicit ReduceLayerOneDnn (int64_t keepdims = 0 ,
20+ const std::vector<int64_t >& axes = {})
21+ : ReduceLayerOneDnn(ReduceLayer::Operation::kSum , keepdims, axes) {}
22+
23+ void run (const std::vector<Tensor>& input,
24+ std::vector<Tensor>& output) override ;
25+
26+ void set_axes (const std::vector<int64_t >& axes) {
27+ axes_ = axes;
28+ initialized_ = false ;
29+ }
30+
31+ void set_keepdims (int64_t keepdims) {
32+ keepdims_ = keepdims;
33+ initialized_ = false ;
34+ }
35+
36+ void set_operation (ReduceLayer::Operation op) {
37+ op_ = op;
38+ initialized_ = false ;
39+ }
40+
41+ #ifdef ENABLE_STATISTIC_WEIGHTS
42+ Tensor get_weights () override { return Tensor (); }
43+ #endif
44+
45+ private:
46+ ReduceLayer::Operation op_;
47+ int64_t keepdims_;
48+ std::vector<int64_t > axes_;
49+ std::vector<int64_t > normalized_axes_;
50+ std::vector<int64_t > last_axes_;
51+
52+ bool initialized_ = false ;
53+ Shape last_input_shape_;
54+ Type last_type_;
55+
56+ std::unique_ptr<dnnl::engine> engine_;
57+ std::unique_ptr<dnnl::stream> stream_;
58+ std::unique_ptr<dnnl::reduction> reduction_prim_;
59+
60+ dnnl::memory::desc src_md_;
61+ dnnl::memory::desc dst_md_;
62+ Shape output_shape_;
63+
64+ void initialize_onednn (const Tensor& input);
65+ static void validate_input (const std::vector<Tensor>& input);
66+ [[nodiscard]] static dnnl::memory::data_type get_dnnl_data_type (Type type);
67+ [[nodiscard]] static dnnl::algorithm get_dnnl_algorithm (
68+ ReduceLayer::Operation op);
69+ [[nodiscard]] static dnnl::memory::format_tag pick_format (size_t ndims);
70+ static void normalize_axes (const Shape& input_shape,
71+ std::vector<int64_t >& axes);
72+ [[nodiscard]] Shape calculate_output_shape (
73+ const Shape& input_shape, const std::vector<int64_t >& axes) const ;
74+
75+ [[nodiscard]] static std::vector<dnnl::memory::dim> shape_to_dims (
76+ const Shape& shape);
77+ template <typename T>
78+ std::vector<T> remove_unit_dims (const std::vector<T>& src_data,
79+ const Shape& src_shape,
80+ const Shape& dst_shape);
81+ };
82+
83+ } // namespace it_lab_ai
0 commit comments