forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathSymIntArrayRef.h
More file actions
180 lines (147 loc) · 5.58 KB
/
SymIntArrayRef.h
File metadata and controls
180 lines (147 loc) · 5.58 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
// This file defines `SymIntArrayRef`, which serves as a view onto
// std::vector<SymInt>. This class is conceptually and mostly functionally
// equivalent to ArrayRef<SymInt>. However, ArrayRef<SymInt> can't be used
// directly, because it introduces ambiguity in cases such as
// `a.expand({1, 2, 3})`, which matches two overloads:
//   `at::Tensor Tensor::expand(c10::SymIntArrayRef size, bool implicit)`
//   `at::Tensor Tensor::expand(at::IntArrayRef size, bool implicit)`
// Introducing a distinct `SymIntArrayRef` type gives finer-grained control
// over which overload will be used.
#pragma once
#include <ATen/core/SymInt.h>
#include <c10/util/ArrayRef.h>
#include <c10/util/Exception.h>
#include <array>
#include <initializer_list>
#include <iterator>
#include <vector>
namespace c10 {
/// SymIntArrayRef - Represent a constant reference to an array (0 or more
/// elements consecutively in memory), i.e. a start pointer and a length. It
/// allows various APIs to take consecutive elements easily and conveniently.
///
/// This class does not own the underlying data, it is expected to be used in
/// situations where the data resides in some other buffer, whose lifetime
/// extends past that of the SymIntArrayRef. For this reason, it is not in
/// general safe to store an SymIntArrayRef.
///
/// This is intended to be trivially copyable, so it should be passed by
/// value.
class SymIntArrayRef final {
public:
using iterator = const c10::SymInt*;
using const_iterator = const c10::SymInt*;
using size_type = size_t;
using value_type = c10::SymInt;
using reverse_iterator = std::reverse_iterator<iterator>;
private:
ArrayRef<c10::SymInt> wrapped_symint_array_ref;
public:
/// @name Constructors
/// @{
/// Construct an empty SymIntArrayRef.
/* implicit */ constexpr SymIntArrayRef() {}
/* implicit */ SymIntArrayRef(const std::vector<c10::SymInt>& Vec)
: wrapped_symint_array_ref(Vec) {}
/// Construct an SymIntArrayRef from a pointer and length.
C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA SymIntArrayRef(
const c10::SymInt* data,
size_t length)
: wrapped_symint_array_ref(data, length) {}
/// Construct an SymIntArrayRef from a range.
C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA SymIntArrayRef(
const c10::SymInt* begin,
const c10::SymInt* end)
: wrapped_symint_array_ref(begin, end) {}
/// Construct an SymIntArrayRef from a C array.
template <size_t N>
/* implicit */ constexpr SymIntArrayRef(const c10::SymInt (&Arr)[N])
: wrapped_symint_array_ref(Arr) {}
/// @}
/// @name Simple Operations
/// @{
constexpr iterator begin() const {
return wrapped_symint_array_ref.begin();
}
constexpr iterator end() const {
return wrapped_symint_array_ref.end();
}
// These are actually the same as iterator, since SymIntArrayRef only
// gives you const iterators.
constexpr const_iterator cbegin() const {
return wrapped_symint_array_ref.cbegin();
}
constexpr const_iterator cend() const {
return wrapped_symint_array_ref.cend();
}
/// empty - Check if the array is empty.
constexpr bool empty() const {
return size() == 0;
}
constexpr const c10::SymInt* data() const {
return wrapped_symint_array_ref.data();
}
/// size - Get the array size.
constexpr size_t size() const {
return wrapped_symint_array_ref.size();
}
/// front - Get the first element.
C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA const c10::SymInt& front() const {
return wrapped_symint_array_ref.front();
}
/// back - Get the last element.
C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA const c10::SymInt& back() const {
return wrapped_symint_array_ref.back();
}
/// equals - Check for element-wise equality.
constexpr bool equals(SymIntArrayRef RHS) const {
return this->wrapped_symint_array_ref.equals(RHS.wrapped_symint_array_ref);
}
/// slice(n, m) - Take M elements of the array starting at element N
C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA SymIntArrayRef
slice(size_t N, size_t M) const {
return SymIntArrayRef(wrapped_symint_array_ref.data() + N, M);
}
/// slice(n) - Chop off the first N elements of the array.
C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA SymIntArrayRef slice(size_t N) const {
return slice(N, size() - N);
}
/// @}
/// @name Operator Overloads
/// @{
constexpr const c10::SymInt& operator[](size_t Index) const {
return wrapped_symint_array_ref[Index];
}
/// Vector compatibility
C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA const c10::SymInt& at(size_t Index) const {
return wrapped_symint_array_ref.at(Index);
}
/// Disallow accidental assignment from a temporary.
///
/// The declaration here is extra complicated so that "arrayRef = {}"
/// continues to select the move assignment operator.
template <typename U>
typename std::enable_if<std::is_same<U, c10::SymInt>::value, SymIntArrayRef>::
type&
operator=(U&& Temporary) = delete;
/// Disallow accidental assignment from a temporary.
///
/// The declaration here is extra complicated so that "arrayRef = {}"
/// continues to select the move assignment operator.
template <typename U>
typename std::enable_if<std::is_same<U, c10::SymInt>::value, SymIntArrayRef>::
type&
operator=(std::initializer_list<U>) = delete;
/// @}
/// @name Expensive Operations
/// @{
std::vector<c10::SymInt> vec() const {
return wrapped_symint_array_ref.vec();
}
friend std::ostream& operator<<(
std::ostream& out,
const SymIntArrayRef& list);
/// @}
};
TORCH_API at::IntArrayRef expectIntArrayRef(c10::SymIntArrayRef ar);
std::ostream& operator<<(std::ostream& out, const c10::SymIntArrayRef& list);
} // namespace c10