Skip to content

Commit 9e03332

Browse files
committed
feat: add max_pool3d_with_indices_backward base
1 parent cfea02e commit 9e03332

1 file changed

Lines changed: 85 additions & 0 deletions

File tree

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
#ifndef INFINI_OPS_BASE_MAX_POOL3D_WITH_INDICES_BACKWARD_H_
2+
#define INFINI_OPS_BASE_MAX_POOL3D_WITH_INDICES_BACKWARD_H_
3+
4+
#include "operator.h"
5+
6+
namespace infini::ops {
7+
8+
class MaxPool3dWithIndicesBackward
9+
: public Operator<MaxPool3dWithIndicesBackward> {
10+
public:
11+
MaxPool3dWithIndicesBackward(const Tensor grad_output, const Tensor input,
12+
const std::vector<int64_t> kernel_size,
13+
const std::vector<int64_t> stride,
14+
const std::vector<int64_t> padding,
15+
const std::vector<int64_t> dilation,
16+
const bool ceil_mode, const Tensor indices,
17+
Tensor grad_input)
18+
: grad_output_shape_{grad_output.shape()},
19+
grad_output_strides_{grad_output.strides()},
20+
grad_output_type_{grad_output.dtype()},
21+
input_shape_{input.shape()},
22+
input_strides_{input.strides()},
23+
input_type_{input.dtype()},
24+
indices_shape_{indices.shape()},
25+
indices_strides_{indices.strides()},
26+
indices_type_{indices.dtype()},
27+
grad_input_shape_{grad_input.shape()},
28+
grad_input_strides_{grad_input.strides()},
29+
grad_input_type_{grad_input.dtype()},
30+
kernel_size_{kernel_size},
31+
stride_{stride},
32+
padding_{padding},
33+
dilation_{dilation},
34+
ceil_mode_{ceil_mode},
35+
device_index_{grad_input.device().index()} {}
36+
37+
virtual void operator()(const Tensor grad_output, const Tensor input,
38+
const std::vector<int64_t> kernel_size,
39+
const std::vector<int64_t> stride,
40+
const std::vector<int64_t> padding,
41+
const std::vector<int64_t> dilation,
42+
const bool ceil_mode, const Tensor indices,
43+
Tensor grad_input) const = 0;
44+
45+
protected:
46+
Tensor::Shape grad_output_shape_;
47+
48+
Tensor::Strides grad_output_strides_;
49+
50+
DataType grad_output_type_;
51+
52+
Tensor::Shape input_shape_;
53+
54+
Tensor::Strides input_strides_;
55+
56+
DataType input_type_;
57+
58+
Tensor::Shape indices_shape_;
59+
60+
Tensor::Strides indices_strides_;
61+
62+
DataType indices_type_;
63+
64+
Tensor::Shape grad_input_shape_;
65+
66+
Tensor::Strides grad_input_strides_;
67+
68+
DataType grad_input_type_;
69+
70+
std::vector<int64_t> kernel_size_{};
71+
72+
std::vector<int64_t> stride_{};
73+
74+
std::vector<int64_t> padding_{};
75+
76+
std::vector<int64_t> dilation_{};
77+
78+
bool ceil_mode_{};
79+
80+
int device_index_{0};
81+
};
82+
83+
} // namespace infini::ops
84+
85+
#endif

0 commit comments

Comments
 (0)