forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathNestedTensorImpl.cpp
More file actions
75 lines (69 loc) · 2.58 KB
/
NestedTensorImpl.cpp
File metadata and controls
75 lines (69 loc) · 2.58 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
#include <ATen/ATen.h>
#include <ATen/NamedTensorUtils.h>
#include <ATen/WrapDimUtils.h>
#include <ATen/core/op_registration/op_registration.h>
#include <ATen/NestedTensorImpl.h>
#include <c10/core/DispatchKey.h>
namespace at {
namespace native {
// Computes the "optional sizes" vector for a nested tensor from its 2-D
// `sizes` tensor (one row per component, one column per non-batch dim).
//
// Layout of the result:
//   result[0]     — number of components (the nested/batch dimension)
//   result[1 + d] — the size of dimension d if it is the same across all
//                   components, otherwise -1 ("ragged" dimension).
//
// A 0-dim `sizes` tensor (empty nested tensor) yields an empty vector.
// Precondition (debug-asserted): otherwise `sizes` is 2-D with int64 data.
inline std::vector<int64_t> construct_opt_sizes(const at::Tensor& sizes) {
  if (sizes.dim() == 0) {
    return std::vector<int64_t>();
  }
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(sizes.dim() == 2);
  // First entry: number of components.
  std::vector<int64_t> result(1, sizes.sizes()[0]);
  const size_t nested_dim = result.size();
  const int64_t* sizes_ptr = sizes.data_ptr<int64_t>();
  const int64_t sizes_size_0 = sizes.sizes()[0];
  const int64_t sizes_size_1 = sizes.sizes()[1];
  result.resize(nested_dim + sizes_size_1);
  // Seed each dimension with the first component's sizes...
  for (const auto i : c10::irange(sizes_size_1)) {
    result[nested_dim + i] = sizes_ptr[i];
  }
  // ...then mark a dimension ragged (-1) as soon as any component disagrees.
  // Once a slot is 0 or -1 it stays that way (the `result[...] &&` guard).
  for (const auto j : c10::irange(sizes_size_1)) {
    for (const auto i : c10::irange(sizes_size_0)) {
      if (result[nested_dim + j] &&
          (result[nested_dim + j] != sizes_ptr[i * sizes_size_1 + j])) {
        result[nested_dim + j] = -1;
      }
    }
  }
  return result;
}
// Builds a nested-tensor implementation from a flat `buffer` tensor holding
// all component data plus a `nested_size_tensor` describing the per-component
// shapes (0-dim for an empty nested tensor, otherwise 2-D — enforced below).
// Member-init order matters: `nested_size_tensor_` must be initialized before
// `opt_sizes_`, which is computed from it via construct_opt_sizes.
NestedTensorImpl::NestedTensorImpl(
at::Tensor buffer,
at::Tensor nested_size_tensor)
: TensorImpl(
// Dispatch key set: the NestedTensor key plus the backend bit matching
// the buffer's device (only CUDA and CPU are accepted; asserted below).
(c10::DispatchKeySet(DispatchKey::NestedTensor) |
c10::DispatchKeySet(buffer.is_cuda() ? BackendComponent::CUDABit : BackendComponent::CPUBit)),
buffer.dtype(),
buffer.device()),
buffer_(std::move(buffer)),
nested_size_tensor_(std::move(nested_size_tensor)),
// Cache per-dimension sizes (-1 where components disagree); see
// construct_opt_sizes for the exact layout.
opt_sizes_(construct_opt_sizes(nested_size_tensor_))
{
TORCH_WARN_ONCE(
"The PyTorch API of nested tensors is in prototype stage and will change "
"in the near future.");
TORCH_INTERNAL_ASSERT(buffer_.is_cuda() || buffer_.is_cpu(), "NestedTensorImpl buffer must be either CUDA or CPU but got ", buffer_);
// construct_opt_sizes and refresh_dim read the size data via raw pointer /
// fixed strides, so the size tensor must be contiguous.
TORCH_INTERNAL_ASSERT(nested_size_tensor_.is_contiguous());
int64_t size_dim = nested_size_tensor_.dim();
TORCH_INTERNAL_ASSERT(size_dim == 0 || size_dim == 2);
// Strip autograd-related dispatch keys from this impl.
// NOTE(review): presumably autograd for nested tensors is handled elsewhere
// in the dispatcher — confirm before relying on this.
remove_autograd_key();
key_set_ =
key_set_ - c10::DispatchKeySet({c10::DispatchKey::ADInplaceOrView});
refresh_dim();
// A single sizes() list is not well-defined for ragged components, so the
// customization path for sizes is explicitly disabled.
set_sizes_customization_policy(CustomizableMethodPolicy::NotSupported);
}
// Resizes the cached sizes/strides storage so that dim() reflects the
// current nested size tensor: 1 for an empty (0-dim) size tensor, otherwise
// the number of per-component dims plus the nested (batch) dimension.
void NestedTensorImpl::refresh_dim() {
  int64_t expected_dim = 1;
  if (nested_size_tensor_.dim() != 0) {
    expected_dim = nested_size_tensor_.sizes()[1] + 1;
  }
  sizes_and_strides_.resize(expected_dim);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dim() == expected_dim);
}
// Human-readable type tag for this TensorImpl subclass (used in diagnostics).
const char* NestedTensorImpl::tensorimpl_type_name() const {
  static constexpr const char* kTypeName = "NestedTensorImpl";
  return kTypeName;
}
} // namespace native
} // namespace at