Skip to content

Commit c338e85

Browse files
authored
[Cpp API Compatibility] Align SymInt with PyTorch (#78807)
1 parent 2c4c7f6 commit c338e85

9 files changed

Lines changed: 42 additions & 13 deletions

File tree

paddle/phi/api/include/compat/ATen/TensorIndexing.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ struct TensorIndex final {
9191

9292
TensorIndex(c10::SymInt integer)
9393
: integer_(std::move(integer)), type_(TensorIndexType::SymInt) {}
94+
TensorIndex(int64_t integer) : TensorIndex(c10::SymInt(integer)) {}
9495
TensorIndex(int integer) : TensorIndex(c10::SymInt(integer)) {}
9596

9697
template <class T, class = std::enable_if_t<std::is_same_v<bool, T>>>

paddle/phi/api/include/compat/ATen/core/TensorBase.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,9 @@ class PADDLE_API TensorBase {
155155
}
156156

157157
c10::SymIntArrayRef sym_strides() const {
158-
return c10::SymIntArrayRef(strides());
158+
return c10::SymIntArrayRef(
159+
reinterpret_cast<const c10::SymInt*>(strides().data()),
160+
strides().size());
159161
}
160162

161163
int64_t size(int64_t dim) const {
@@ -173,7 +175,10 @@ class PADDLE_API TensorBase {
173175
return compat::_PD_PhiDDimToIntArrayRef(tensor_.dims());
174176
}
175177

176-
c10::SymIntArrayRef sym_sizes() const { return c10::SymIntArrayRef(sizes()); }
178+
c10::SymIntArrayRef sym_sizes() const {
179+
return c10::SymIntArrayRef(
180+
reinterpret_cast<const c10::SymInt*>(sizes().data()), sizes().size());
181+
}
177182

178183
int64_t numel() const { return tensor_.numel(); }
179184

paddle/phi/api/include/compat/ATen/ops/split.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,10 +59,11 @@ inline std::vector<at::Tensor> split(const at::Tensor& self,
5959
inline std::vector<at::Tensor> split_symint(const at::Tensor& self,
6060
c10::SymIntArrayRef split_sizes,
6161
int64_t dim = 0) {
62-
return split(self,
63-
at::IntArrayRef(static_cast<const int64_t*>(split_sizes.data()),
64-
split_sizes.size()),
65-
dim);
62+
return split(
63+
self,
64+
at::IntArrayRef(reinterpret_cast<const int64_t*>(split_sizes.data()),
65+
split_sizes.size()),
66+
dim);
6667
}
6768

6869
} // namespace at

paddle/phi/api/include/compat/ATen/ops/tensor_split.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ inline std::vector<at::Tensor> tensor_split_symint(const at::Tensor& self,
119119
int64_t dim = 0) {
120120
return tensor_split(
121121
self,
122-
at::IntArrayRef(static_cast<const int64_t*>(indices.data()),
122+
at::IntArrayRef(reinterpret_cast<const int64_t*>(indices.data()),
123123
indices.size()),
124124
dim);
125125
}

paddle/phi/api/include/compat/ATen/ops/unflatten.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,11 @@ inline at::Tensor unflatten(const at::Tensor& self,
4040
inline at::Tensor unflatten_symint(const at::Tensor& self,
4141
const int64_t dim,
4242
c10::SymIntArrayRef sizes) {
43-
// SymIntArrayRef is the same as IntArrayRef in this implementation
44-
return unflatten(self, dim, sizes);
43+
return unflatten(
44+
self,
45+
dim,
46+
at::IntArrayRef(reinterpret_cast<const int64_t*>(sizes.data()),
47+
sizes.size()));
4548
}
4649

4750
} // namespace at

paddle/phi/api/include/compat/c10/core/SymInt.h

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,26 @@
1515
#pragma once
1616
#include <c10/util/accumulate.h>
1717
#include <cstdint>
18+
#include <optional>
1819

1920
namespace c10 {
20-
using SymInt = int64_t;
21+
22+
class SymInt {
23+
public:
24+
SymInt() : data_(0) {}
25+
/*implicit*/ SymInt(int64_t d) : data_(d) {} // NOLINT
26+
/*implicit*/ operator int64_t() const { return data_; }
27+
28+
int64_t guard_int(const char* file, int64_t line) const {
29+
(void)file;
30+
(void)line;
31+
return data_;
32+
}
33+
34+
std::optional<int64_t> maybe_as_int() const { return data_; }
35+
36+
private:
37+
int64_t data_;
38+
};
2139

2240
} // namespace c10

paddle/phi/api/include/compat/c10/core/SymIntArrayRef.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
#include <c10/util/ArrayRef.h>
1919

2020
namespace c10 {
21-
using SymIntArrayRef = IntArrayRef; // SymIntArrayRef is same as ArrayRef
21+
using SymIntArrayRef = ArrayRef<SymInt>;
2222
} // namespace c10
2323

2424
namespace at {

paddle/phi/api/include/compat/c10/util/ArrayRef.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,8 @@ class ArrayRef {
152152
std::vector<T> vec() const { return std::vector<T>(Data, Data + Length); }
153153

154154
const paddle::experimental::IntArray _PD_ToPaddleIntArray() const {
155-
return paddle::experimental::IntArray(Data, Length);
155+
return paddle::experimental::IntArray(
156+
reinterpret_cast<const int64_t*>(Data), Length);
156157
}
157158
};
158159

test/cpp/compat/ATen_flatten_test.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ TEST(TestUnflatten, UnflattenSymInt) {
146146

147147
// Unflatten dimension 1 using symint version
148148
// Note: Must keep the underlying data alive
149-
std::vector<int64_t> sizes_vec = {2, 3};
149+
std::vector<c10::SymInt> sizes_vec = {2, 3};
150150
c10::SymIntArrayRef sizes(sizes_vec);
151151
at::Tensor unflattened = tensor.unflatten_symint(1, sizes);
152152
ASSERT_EQ(unflattened.sizes(), c10::IntArrayRef({4, 2, 3, 8}));

0 commit comments

Comments
 (0)