1+ #ifndef INFINI_OPS_BASE_UPSAMPLE_NEAREST1D_BACKWARD_H_
2+ #define INFINI_OPS_BASE_UPSAMPLE_NEAREST1D_BACKWARD_H_
3+
4+ #include " operator.h"
5+
6+ namespace infini ::ops {
7+
8+ class UpsampleNearest1dBackward : public Operator <UpsampleNearest1dBackward> {
9+ public:
10+ UpsampleNearest1dBackward (const Tensor grad_output,
11+ const std::vector<int64_t > output_size,
12+ const std::vector<int64_t > input_size,
13+ Tensor grad_input)
14+ : grad_output_shape_{grad_output.shape ()},
15+ grad_output_strides_{grad_output.strides ()},
16+ grad_output_type_{grad_output.dtype ()},
17+ grad_input_shape_{grad_input.shape ()},
18+ grad_input_strides_{grad_input.strides ()},
19+ grad_input_type_{grad_input.dtype ()},
20+ output_size_{output_size},
21+ input_size_{input_size},
22+ device_index_{grad_input.device ().index ()} {}
23+
24+ virtual void operator ()(const Tensor grad_output,
25+ const std::vector<int64_t > output_size,
26+ const std::vector<int64_t > input_size,
27+ Tensor grad_input) const = 0;
28+
29+ protected:
30+ Tensor::Shape grad_output_shape_;
31+
32+ Tensor::Strides grad_output_strides_;
33+
34+ DataType grad_output_type_;
35+
36+ Tensor::Shape grad_input_shape_;
37+
38+ Tensor::Strides grad_input_strides_;
39+
40+ DataType grad_input_type_;
41+
42+ std::vector<int64_t > output_size_{};
43+
44+ std::vector<int64_t > input_size_{};
45+
46+ int device_index_{0 };
47+ };
48+
49+ } // namespace infini::ops
50+
51+ #endif
0 commit comments