forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathNestedTensorImpl.h
More file actions
105 lines (93 loc) · 3.12 KB
/
NestedTensorImpl.h
File metadata and controls
105 lines (93 loc) · 3.12 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
#pragma once
#include <ATen/Tensor.h>
#include <c10/core/TensorImpl.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>
#include <ATen/MemoryOverlap.h>
#include <c10/core/MemoryFormat.h>
#include <c10/util/Metaprogramming.h>
namespace at {
namespace native {
// TensorImpl subclass backing nested tensors.
//
// State is held in two tensors: `buffer_` (exposed via get_buffer()) and
// `nested_size_tensor_` (exposed via get_nested_size_tensor()).
// NOTE(review): presumably buffer_ holds the constituent tensors' elements
// and nested_size_tensor_ holds one row of sizes per constituent tensor —
// confirm against the out-of-line constructor.
//
// Per-dimension sizes are stored in `opt_sizes_`, where -1 marks a
// dimension without a single size; query them through opt_size(). When
// C10_DISABLE_TENSORIMPL_EXTENSIBILITY is not defined, the generic
// TensorImpl accessors numel(), is_contiguous(), sizes() and strides() are
// overridden to fail with a TORCH_CHECK error.
struct TORCH_API NestedTensorImpl : public c10::TensorImpl {
  explicit NestedTensorImpl(at::Tensor buffer, at::Tensor nested_size_tensor);

#ifndef C10_DISABLE_TENSORIMPL_EXTENSIBILITY
  // Unsupported: always fails. The trailing return is unreachable; it gives
  // every path an explicit return, matching sizes()/strides() below.
  int64_t numel() const override {
    TORCH_CHECK(
        false, "numel is disabled. These methods are not virtual in fbcode.");
    return 0;
  }

  // Unsupported: always fails, whatever `memory_format` is requested.
  bool is_contiguous(at::MemoryFormat memory_format) const override {
    (void)memory_format; // never consulted; silences -Wunused-parameter
    TORCH_CHECK(
        false,
        "is_contiguous is disabled. These methods are not virtual in fbcode.");
    return false;
  }
#endif

  // TODO: don't expose private implementation details like this; in
  // particular, resizing this tensor will mess up our dim() and
  // callers cannot fix it.
  const Tensor& get_nested_size_tensor() const {
    return nested_size_tensor_;
  }

  // Returns the size of dimension `d`, or nullopt if that dimension is
  // irregular. The ith dimension of a NestedTensor is regular if the
  // unbound tensors match in size at the (i-1)th dimension.
  // `d` is normalized with at::maybe_wrap_dim against dim() before use.
  // NOTE(review): opt_sizes_ is indexed without a bounds check; this
  // assumes it holds exactly dim() entries.
  c10::optional<int64_t> opt_size(int64_t d) const {
    d = at::maybe_wrap_dim(d, dim(), false);
    if (opt_sizes_[d] == -1) {
      return c10::nullopt;
    }
    return opt_sizes_[d];
  }

#ifndef C10_DISABLE_TENSORIMPL_EXTENSIBILITY
  // Unsupported: always fails; the return is unreachable.
  IntArrayRef sizes() const override {
    TORCH_CHECK(
        false,
        "Internal error: NestedTensorImpl doesn't support sizes. Please file an issue on https://github.com/pytorch/nestedtensor");
    return IntArrayRef();
  }

  // Unsupported: always fails; the return is unreachable.
  IntArrayRef strides() const override {
    TORCH_CHECK(
        false,
        "Internal error: NestedTensorImpl doesn't support strides. Please file an issue on https://github.com/pytorch/nestedtensor");
    return IntArrayRef();
  }
#endif

  // Returns the buffer tensor held by this impl (presumably the `buffer`
  // passed to the constructor — confirm in the .cpp).
  const at::Tensor& get_buffer() const {
    return buffer_;
  }

 protected:
  const char* tensorimpl_type_name() const override;

 private:
  // Must be called after any changes to our dim() to sync the state
  // to TensorImpl.
  void refresh_dim();

  at::Tensor buffer_;
  const at::Tensor nested_size_tensor_;
  // NOTE: -1 here means the size is missing
  std::vector<int64_t> opt_sizes_;
};
// Downcasts `tensor`'s impl to NestedTensorImpl when `tensor` is nested;
// for any other tensor it yields nullptr instead of failing.
// NOTE(review): assumes every tensor reporting is_nested() is backed by a
// NestedTensorImpl (the static_cast is unchecked). The pointer comes from
// unsafeGetTensorImpl() and is not owned by the caller — presumably it is
// only valid while `tensor` (or another owner of the impl) is alive.
inline NestedTensorImpl* get_nested_tensor_impl_or_null(
    const at::Tensor& tensor) {
  return tensor.is_nested()
      ? static_cast<NestedTensorImpl*>(tensor.unsafeGetTensorImpl())
      : nullptr;
}
// Checked counterpart of get_nested_tensor_impl_or_null(): returns the
// NestedTensorImpl behind `tensor`, raising a TORCH_CHECK error when
// `tensor` is not nested.
// NOTE(review): the downcast is unchecked beyond is_nested(), and the
// returned pointer is borrowed, not owned.
inline NestedTensorImpl* get_nested_tensor_impl(const at::Tensor& tensor) {
  TORCH_CHECK(
      tensor.is_nested(), "get_nested_tensor_impl requires a NestedTensor.");
  auto* const impl = tensor.unsafeGetTensorImpl();
  return static_cast<NestedTensorImpl*>(impl);
}
// TODO: real implementation once we support strides.
// Placeholder contiguity query for a nested tensor. `nt` is not inspected
// yet: the answer depends only on `memory_format`, and only
// MemoryFormat::Contiguous (the default) is reported as contiguous.
inline bool nested_tensor_impl_is_contiguous(
    const NestedTensorImpl* nt,
    at::MemoryFormat memory_format = MemoryFormat::Contiguous) {
  (void)nt; // not consulted until strides are supported
  const bool is_default_layout = (memory_format == MemoryFormat::Contiguous);
  return is_default_layout;
}
} // namespace native
} // namespace at