Skip to content

Commit bf4d338

Browse files
authored
feat: add max_pool3d_with_indices base (#330)
1 parent 788a2d9 commit bf4d338

1 file changed

Lines changed: 74 additions & 0 deletions

File tree

src/base/max_pool3d_with_indices.h

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
#ifndef INFINI_OPS_BASE_MAX_POOL3D_WITH_INDICES_H_
2+
#define INFINI_OPS_BASE_MAX_POOL3D_WITH_INDICES_H_
3+
4+
#include "operator.h"
5+
6+
namespace infini::ops {
7+
8+
class MaxPool3dWithIndices : public Operator<MaxPool3dWithIndices> {
9+
public:
10+
MaxPool3dWithIndices(const Tensor input,
11+
const std::vector<int64_t> kernel_size,
12+
const std::vector<int64_t> stride,
13+
const std::vector<int64_t> padding,
14+
const std::vector<int64_t> dilation,
15+
const bool ceil_mode, Tensor out, Tensor indices)
16+
: input_shape_{input.shape()},
17+
input_strides_{input.strides()},
18+
input_type_{input.dtype()},
19+
out_shape_{out.shape()},
20+
out_strides_{out.strides()},
21+
out_type_{out.dtype()},
22+
indices_shape_{indices.shape()},
23+
indices_strides_{indices.strides()},
24+
indices_type_{indices.dtype()},
25+
kernel_size_{kernel_size},
26+
stride_{stride},
27+
padding_{padding},
28+
dilation_{dilation},
29+
ceil_mode_{ceil_mode},
30+
device_index_{out.device().index()} {}
31+
32+
virtual void operator()(const Tensor input,
33+
const std::vector<int64_t> kernel_size,
34+
const std::vector<int64_t> stride,
35+
const std::vector<int64_t> padding,
36+
const std::vector<int64_t> dilation,
37+
const bool ceil_mode, Tensor out,
38+
Tensor indices) const = 0;
39+
40+
protected:
41+
Tensor::Shape input_shape_;
42+
43+
Tensor::Strides input_strides_;
44+
45+
DataType input_type_;
46+
47+
Tensor::Shape out_shape_;
48+
49+
Tensor::Strides out_strides_;
50+
51+
DataType out_type_;
52+
53+
Tensor::Shape indices_shape_;
54+
55+
Tensor::Strides indices_strides_;
56+
57+
DataType indices_type_;
58+
59+
std::vector<int64_t> kernel_size_{};
60+
61+
std::vector<int64_t> stride_{};
62+
63+
std::vector<int64_t> padding_{};
64+
65+
std::vector<int64_t> dilation_{};
66+
67+
bool ceil_mode_{};
68+
69+
int device_index_{0};
70+
};
71+
72+
} // namespace infini::ops
73+
74+
#endif

0 commit comments

Comments
 (0)