forked from pytorch/executorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtensor_layout.cpp
More file actions
66 lines (60 loc) · 1.88 KB
/
tensor_layout.cpp
File metadata and controls
66 lines (60 loc) · 1.88 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <c10/util/irange.h>
#include <c10/util/safe_numerics.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
#include <executorch/runtime/core/span.h>
#include <executorch/runtime/core/tensor_layout.h>
namespace executorch {
namespace ET_RUNTIME_NAMESPACE {
namespace {
/**
 * Computes how many bytes a tensor with the given shape and element type
 * occupies.
 *
 * @param sizes Per-dimension sizes. Every entry must be non-negative.
 * @param scalar_type Element type; its byte width comes from
 *     executorch::runtime::elementSize().
 * @return (product of all sizes) * (element size), or Error::InvalidArgument
 *     if any size is negative or any intermediate product overflows size_t.
 *     An empty `sizes` is treated as a scalar (one element).
 */
Result<size_t> calculate_nbytes(
    const Span<const int32_t>& sizes,
    const executorch::aten::ScalarType& scalar_type) {
  // Running element count; starts at 1 so a zero-dim (scalar) shape has one
  // element.
  size_t numel = 1;
  for (const int32_t dim_size : sizes) {
    // A negative extent cannot describe real storage.
    if (dim_size < 0) {
      return Error::InvalidArgument;
    }
    size_t product = 0;
    if (c10::mul_overflows(numel, static_cast<size_t>(dim_size), &product)) {
      return Error::InvalidArgument;
    }
    numel = product;
  }

  // Use the full namespace to disambiguate from c10::elementSize.
  const size_t bytes_per_element =
      static_cast<size_t>(executorch::runtime::elementSize(scalar_type));

  size_t nbytes = 0;
  if (c10::mul_overflows(numel, bytes_per_element, &nbytes)) {
    return Error::InvalidArgument;
  }
  return nbytes;
}
} // namespace
/**
 * Validates the given shape, dim order and element type, and constructs a
 * TensorLayout from them.
 *
 * @param sizes Per-dimension sizes. Each must be non-negative, and their
 *     product times the element size must fit in size_t.
 * @param dim_order Dimension order. It must have exactly one entry per
 *     dimension and be a permutation of [0, sizes.size()): every entry in
 *     range and no entry repeated.
 * @param scalar_type Element type of the tensor.
 * @return The constructed layout, or Error::InvalidArgument if any of the
 *     above checks fails.
 */
Result<const TensorLayout> TensorLayout::create(
    Span<const int32_t> sizes,
    Span<const uint8_t> dim_order,
    executorch::aten::ScalarType scalar_type) {
  auto nbytes = calculate_nbytes(sizes, scalar_type);
  if (!nbytes.ok()) {
    return nbytes.error();
  }

  if (dim_order.size() != sizes.size()) {
    return Error::InvalidArgument;
  }

  // The length equals the number of dims. So requiring every entry to be in
  // range AND unique makes dim_order a true permutation. A range check alone
  // would accept e.g. {0, 0} for a 2-D tensor.
  // Entries are uint8_t, so 256 slots cover every representable value.
  bool seen[256] = {};
  for (const uint8_t dim : dim_order) {
    if (dim >= sizes.size() || seen[dim]) {
      return Error::InvalidArgument;
    }
    seen[dim] = true;
  }

  return TensorLayout(sizes, dim_order, scalar_type, nbytes.get());
}
} // namespace ET_RUNTIME_NAMESPACE
} // namespace executorch