Skip to content

Commit 9d23455

Browse files
committed
feat: add searchsorted base
1 parent 832e048 commit 9d23455

1 file changed

Lines changed: 71 additions & 0 deletions

File tree

src/base/searchsorted.h

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
// AUTO-GENERATED by `scripts/generate_torch_ops.py` — DO NOT EDIT.
2+
#ifndef INFINI_OPS_BASE_SEARCHSORTED_H_
3+
#define INFINI_OPS_BASE_SEARCHSORTED_H_
4+
5+
#include "operator.h"
6+
7+
namespace infini::ops {
8+
9+
class Searchsorted : public Operator<Searchsorted> {
10+
public:
11+
Searchsorted(const Tensor sorted_sequence, const Tensor input, const bool out_int32, const bool right, Tensor out)
12+
: sorted_sequence_shape_{sorted_sequence.shape()},
13+
sorted_sequence_strides_{sorted_sequence.strides()},
14+
sorted_sequence_type_{sorted_sequence.dtype()},
15+
input_shape_{input.shape()},
16+
input_strides_{input.strides()},
17+
input_type_{input.dtype()},
18+
out_shape_{out.shape()},
19+
out_strides_{out.strides()},
20+
out_type_{out.dtype()},
21+
out_int32_{out_int32},
22+
right_{right},
23+
device_index_{out.device().index()} {}
24+
25+
Searchsorted(const Tensor sorted_sequence, const double input, const bool out_int32, const bool right, Tensor out)
26+
: sorted_sequence_shape_{sorted_sequence.shape()},
27+
sorted_sequence_strides_{sorted_sequence.strides()},
28+
sorted_sequence_type_{sorted_sequence.dtype()},
29+
out_shape_{out.shape()},
30+
out_strides_{out.strides()},
31+
out_type_{out.dtype()},
32+
out_int32_{out_int32},
33+
right_{right},
34+
input_{input},
35+
device_index_{out.device().index()} {}
36+
37+
virtual void operator()(const Tensor sorted_sequence, const Tensor input, const bool out_int32, const bool right, Tensor out) const = 0;
38+
39+
virtual void operator()(const Tensor sorted_sequence, const double input, const bool out_int32, const bool right, Tensor out) const = 0;
40+
41+
protected:
42+
Tensor::Shape sorted_sequence_shape_;
43+
44+
Tensor::Strides sorted_sequence_strides_;
45+
46+
DataType sorted_sequence_type_;
47+
48+
Tensor::Shape input_shape_;
49+
50+
Tensor::Strides input_strides_;
51+
52+
DataType input_type_;
53+
54+
Tensor::Shape out_shape_;
55+
56+
Tensor::Strides out_strides_;
57+
58+
DataType out_type_;
59+
60+
bool out_int32_{};
61+
62+
bool right_{};
63+
64+
double input_{};
65+
66+
int device_index_{0};
67+
};
68+
69+
} // namespace infini::ops
70+
71+
#endif

0 commit comments

Comments
 (0)