forked from pytorch/executorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathinputs_portable.cpp
More file actions
83 lines (70 loc) · 2.44 KB
/
inputs_portable.cpp
File metadata and controls
83 lines (70 loc) · 2.44 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
* Copyright 2025 Arm Limited and/or its affiliates.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <executorch/extension/runner_util/inputs.h>
#include <algorithm>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
#include <executorch/runtime/executor/method.h>
#include <executorch/runtime/executor/method_meta.h>
#include <executorch/runtime/platform/log.h>
using executorch::aten::Tensor;
using executorch::aten::TensorImpl;
using executorch::runtime::Error;
using executorch::runtime::Method;
using executorch::runtime::TensorInfo;
namespace executorch {
namespace extension {
namespace internal {
namespace {
/**
 * Sets every element of `tensor` to 1.
 *
 * Dispatches on the tensor's scalar type over the types enumerated by
 * ET_FORALL_REALHBBF16_TYPES and writes T(1) into each of its numel()
 * elements, in place.
 *
 * @param[in] tensor The tensor whose data buffer is overwritten.
 *
 * @retval Error::Ok The buffer was filled, or the tensor has no elements.
 * @retval Error::InvalidArgument The scalar type is not one of the handled
 *     types, or the tensor has elements but a null data pointer.
 */
Error fill_ones(torch::executor::Tensor tensor) {
  const auto element_count = tensor.numel();

  // One switch case per supported scalar type. The typed data pointer is
  // fetched once and checked before any write, so a tensor that claims to
  // have elements but has no backing buffer is reported as an error instead
  // of being written through a null pointer.
#define FILL_CASE(T, n)                                                  \
  case (torch::executor::ScalarType::n): {                               \
    T* const data = tensor.mutable_data_ptr<T>();                        \
    if (data == nullptr && element_count > 0) {                          \
      ET_LOG(Error, "Cannot fill a tensor whose data pointer is null");  \
      return Error::InvalidArgument;                                     \
    }                                                                    \
    std::fill_n(data, element_count, T(1));                              \
    break;                                                               \
  }

  switch (tensor.scalar_type()) {
    ET_FORALL_REALHBBF16_TYPES(FILL_CASE)
    default:
      ET_LOG(
          Error,
          "Unsupported scalar type %d",
          static_cast<int>(tensor.scalar_type()));
      return Error::InvalidArgument;
  }

#undef FILL_CASE

  return Error::Ok;
}
} // namespace
/**
 * Wraps `data_ptr` in a temporary tensor described by `tensor_meta`,
 * optionally fills it with ones, and registers it as an input of `method`.
 *
 * @param[in] method The method whose input is being set.
 * @param[in] tensor_meta Supplies the scalar type, sizes and dim order used
 *     to describe the input tensor.
 * @param[in] input_index Index of the input passed to Method::set_input().
 * @param[in] data_ptr Buffer used as the tensor's data.
 * @param[in] fill_tensor When true, the buffer is filled with ones before the
 *     input is set.
 *
 * @returns The error from filling the tensor if that step fails; otherwise
 *     the result of Method::set_input().
 */
Error fill_and_set_input(
    Method& method,
    TensorInfo& tensor_meta,
    size_t input_index,
    void* data_ptr,
    bool fill_tensor) {
  const auto shape = tensor_meta.sizes();
  const auto dim_order = tensor_meta.dim_order();

  // The const_casts are safe: this TensorImpl is never resized and lives only
  // for the duration of this call. Its sole purpose is to let set_input()
  // validate the shape; the Method keeps its own sizes and dim_order arrays
  // for the input.
  TensorImpl scratch_impl(
      tensor_meta.scalar_type(),
      /*dim=*/shape.size(),
      const_cast<TensorImpl::SizesType*>(shape.data()),
      data_ptr,
      const_cast<TensorImpl::DimOrderType*>(dim_order.data()));
  Tensor input(&scratch_impl);

  if (!fill_tensor) {
    return method.set_input(input, input_index);
  }

  ET_CHECK_OK_OR_RETURN_ERROR(fill_ones(input));
  return method.set_input(input, input_index);
}
} // namespace internal
} // namespace extension
} // namespace executorch