forked from pytorch/executorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmake_boxed_from_unboxed_functor.h
More file actions
260 lines (230 loc) · 8.55 KB
/
make_boxed_from_unboxed_functor.h
File metadata and controls
260 lines (230 loc) · 8.55 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
//===----------------------------------------------------------------------===//
/// \file extension/kernel_util/make_boxed_from_unboxed_functor.h
/// Defines a template that can be used to create a boxed version of an unboxed
/// functor.
/// Example usage:
/// ```
/// Tensor&
/// my_op(KernelRuntimeContext& ctx, const Tensor& self, const Tensor& other,
/// Tensor& out)
/// {
/// // ...
/// return out;
/// }
///
/// Kernel my_kernel = make_boxed_kernel("my_ns::my_op",
/// EXECUTORCH_FN(my_op));
/// static auto res = register_kernels({my_kernel});
/// ```
/// Or simply:
/// ```
/// EXECUTORCH_LIBRARY(my_ns, "my_op", my_op);
/// ```
///
/// The trick here is to convert each EValue to inferred argument type. This
/// uses a lot of C++17 features.
//===----------------------------------------------------------------------===//
#pragma once
#include <executorch/extension/kernel_util/meta_programming.h>
#include <executorch/extension/kernel_util/type_list.h>
#include <executorch/runtime/core/evalue.h>
#include <executorch/runtime/core/event_tracer_hooks.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/kernel/kernel_runtime_context.h>
#include <executorch/runtime/kernel/operator_registry.h>
#include <cstdlib>
#include <memory>
#include <type_traits>
#include <typeinfo>
// Forward declaration of the runtime context type that is passed as the first
// argument to every boxed kernel produced by this header.
namespace executorch::runtime {
class KernelRuntimeContext;
} // namespace executorch::runtime
namespace executorch {
namespace extension {
// This extension has a lot of generic internal names like "size"; use a unique
// internal namespace to avoid conflicts with other extensions.
namespace kernel_util_internal {
// Trait: `value` is true exactly when T is `executorch::aten::Tensor&` (a
// mutable tensor reference) and false for every other type, including
// `const Tensor&` and `Tensor` taken by value. Derives from
// std::integral_constant<bool, ...> just like std::true_type/false_type.
template <class T>
struct is_nonconst_tensor : std::is_same<T, executorch::aten::Tensor&> {};
// Counts the non-const tensor parameters (`executorch::aten::Tensor&`) at the
// END of a typelist, i.e. the length of the longest suffix made up only of
// non-const tensors. A non-const tensor that is followed by a parameter of any
// other type is NOT counted. Examples:
//   typelist<const Tensor&, Tensor&>     -> 1
//   typelist<Tensor&, int64_t, Tensor&>  -> 1
//   typelist<Tensor&, Tensor&>           -> 2
//   typelist<Tensor&, const Tensor&>     -> 0
template <class TypeList>
struct count_nonconst_tensors;

// Base case: an empty list has no trailing tensors.
template <>
struct count_nonconst_tensors<typelist<>> {
  static constexpr size_t value = 0;
};

// Recursive case: Head extends the trailing run only if it is itself a
// non-const tensor AND every element of Tail already belongs to the run.
// (A single-element list is handled here with an empty Tail.)
template <class Head, class... Tail>
struct count_nonconst_tensors<typelist<Head, Tail...>> {
 private:
  static constexpr size_t tail_tensor_count =
      count_nonconst_tensors<typelist<Tail...>>::value;
  static constexpr bool all_tail_args_are_tensor =
      tail_tensor_count == sizeof...(Tail);
  static constexpr bool head_extends_run =
      is_nonconst_tensor<Head>::value && all_tail_args_are_tensor;

 public:
  static constexpr size_t value =
      tail_tensor_count + (head_extends_run ? 1 : 0);
};
// Maps a kernel parameter type to the type that evalue_to_arg is instantiated
// with. `Tensor&` and `const Tensor&` pass through unchanged so the kernel can
// receive a reference; every other type is std::decay_t'ed (reference and
// cv-qualifiers stripped) so it is produced by value.
template <class T>
struct decay_if_not_tensor final {
  using type = std::conditional_t<
      std::is_same<T, executorch::aten::Tensor&>::value ||
          std::is_same<T, const executorch::aten::Tensor&>::value,
      T,
      std::decay_t<T>>;
};
// Converts an EValue taken from the stack into a value of kernel parameter
// type T. The primary template calls EValue::to<T>() on `std::move(v)`, so the
// rvalue-qualified overload is selected (it may move out of the EValue). The
// specializations below handle types that need a dedicated accessor.
template <class T>
struct evalue_to_arg final {
  static T call(executorch::runtime::EValue& v) {
    return std::move(v).to<T>();
  }
};

// Mutable tensors are returned by reference to the tensor held in the EValue.
template <>
struct evalue_to_arg<executorch::aten::Tensor&> final {
  static executorch::aten::Tensor& call(executorch::runtime::EValue& v) {
    return v.toTensor();
  }
};

// Read-only tensors are likewise returned by reference, not copied.
template <>
struct evalue_to_arg<const executorch::aten::Tensor&> final {
  static const executorch::aten::Tensor& call(executorch::runtime::EValue& v) {
    return v.toTensor();
  }
};

// Optional arguments are unpacked with EValue::toOptional<T>().
template <class T>
struct evalue_to_arg<std::optional<T>> final {
  static std::optional<T> call(executorch::runtime::EValue& v) {
    return v.toOptional<T>();
  }
};

// Lists of optional values. The only accessor used here is
// EValue::toListOptionalTensor(), so T must be Tensor; any other T would fail
// with an obscure ArrayRef conversion error, hence the explicit check.
template <class T>
struct evalue_to_arg<executorch::aten::ArrayRef<std::optional<T>>> final {
  static_assert(
      std::is_same<T, executorch::aten::Tensor>::value,
      "ArrayRef<std::optional<T>> kernel arguments are only supported for "
      "T = Tensor");
  static executorch::aten::ArrayRef<std::optional<T>> call(
      executorch::runtime::EValue& v) {
    return v.toListOptionalTensor();
  }
};
/**
 * Unboxes the EValues in `stack` and invokes the kernel wrapped by `Functor`.
 *
 * Template parameters:
 *  - Functor: compile-time function pointer wrapper (from EXECUTORCH_FN) that
 *    exposes `func_ptr()` and `func_name_`.
 *  - nonconst_tensors_to_log: how many of the last consumed stack entries are
 *    passed to event_tracer_log_evalue after the kernel returns.
 *  - evalue_arg_indices / ArgTypes: deduced from the index_sequence and the
 *    typelist tag; stack entry i is converted to the i-th type in ArgTypes.
 *
 * `ctx` is always passed as the kernel's first argument, and the kernel's
 * return value is discarded. The event-tracer op scope object is created
 * before the kernel call and lives until this function returns.
 */
template <
    class Functor,
    size_t nonconst_tensors_to_log,
    size_t... evalue_arg_indices,
    typename... ArgTypes>
void call_functor_with_args_from_stack(
    executorch::runtime::KernelRuntimeContext& ctx,
    executorch::runtime::Span<executorch::runtime::EValue*> stack,
    std::index_sequence<evalue_arg_indices...>,
    typelist<ArgTypes...>*) {
  // Number of stack entries consumed by the kernel, and the first of the
  // trailing entries that get logged afterwards.
  constexpr size_t arg_count = sizeof...(evalue_arg_indices);
  constexpr size_t first_logged = arg_count - nonconst_tensors_to_log;

  executorch::runtime::internal::EventTracerProfileOpScope op_scope(
      ctx.internal_event_tracer(), Functor::func_name_);
  EXECUTORCH_SCOPE_PROF(Functor::func_name_);

  (*Functor::func_ptr())(
      ctx,
      evalue_to_arg<typename decay_if_not_tensor<ArgTypes>::type>::call(
          *stack[evalue_arg_indices])...);

  for (size_t idx = first_logged; idx < arg_count; ++idx) {
    executorch::runtime::internal::event_tracer_log_evalue(
        ctx.internal_event_tracer(), *stack[idx]);
  }
}
} // namespace kernel_util_internal
/**
 * WrapUnboxedIntoFunctor: Given a compile-time function pointer (as produced
 * by EXECUTORCH_FN), wrap it into a functor whose static `call` takes a
 * KernelRuntimeContext plus a span of EValue pointers and returns void. `call`
 * unboxes each EValue into the corresponding kernel parameter type and
 * forwards everything to the unboxed kernel.
 *
 * Requirements on the wrapped kernel, checked by static_assert in `call`:
 *  - its first parameter is a KernelRuntimeContext reference (const or not);
 *  - exactly one non-const `Tensor&` parameter trails the parameter list (the
 *    out tensor); the trailing non-const tensors are logged to the event
 *    tracer after the kernel runs.
 */
template <class FuncType>
struct WrapUnboxedIntoFunctor {
  static_assert(
      kernel_util_internal::is_compile_time_function_pointer<FuncType>::value,
      "Can't handle function other than EXECUTORCH_FN");
  using TrueType = typename FuncType::FuncType;
  using ReturnType = typename kernel_util_internal::infer_function_traits_t<
      TrueType>::return_type;
  using ArgsType = typename kernel_util_internal::infer_function_traits_t<
      TrueType>::parameter_types;
  // True if the first parameter is a KernelRuntimeContext. Reference and
  // cv-qualifiers are stripped so `const KernelRuntimeContext&` is recognized
  // as well (the non-const ctx passed at call time binds to it).
  static constexpr bool first_arg_is_context = std::is_same<
      ::executorch::runtime::KernelRuntimeContext,
      std::remove_cv_t<std::remove_reference_t<
          kernel_util_internal::head_with_default_t<void, ArgsType>>>>::value;
  // Parameters that are filled from the EValue stack (context removed).
  using ContextRemovedArgsType = std::conditional_t<
      first_arg_is_context,
      kernel_util_internal::drop_if_nonempty_t<ArgsType, 1>,
      ArgsType>;

  static void call(
      ::executorch::runtime::KernelRuntimeContext& ctx,
      executorch::runtime::Span<executorch::runtime::EValue*> stack) {
    // call_functor_with_args_from_stack always passes `ctx` as the first
    // argument, so a kernel without a leading context parameter can never be
    // called correctly; fail with a clear message instead of an arity error.
    static_assert(
        first_arg_is_context,
        "Kernel's first parameter must be KernelRuntimeContext&");
    constexpr size_t num_inputs =
        kernel_util_internal::size<ContextRemovedArgsType>::value;
    // Length of the trailing run of non-const Tensor& parameters.
    constexpr size_t num_nonconst_tensors =
        kernel_util_internal::count_nonconst_tensors<
            ContextRemovedArgsType>::value;
    static_assert(
        num_nonconst_tensors == 1,
        "Kernel must end with exactly one non-const Tensor& (out) parameter");
    kernel_util_internal::
        call_functor_with_args_from_stack<FuncType, num_nonconst_tensors>(
            ctx,
            stack,
            std::make_index_sequence<num_inputs>(),
            static_cast<ContextRemovedArgsType*>(nullptr));
  }
};
/**
 * Builds a boxed executorch::runtime::Kernel named `name` whose entry point is
 * WrapUnboxedIntoFunctor<FuncType>::call. The second argument is used only
 * for its type (pass EXECUTORCH_FN(my_op)); its value is never read.
 */
template <typename FuncType>
static executorch::runtime::Kernel make_boxed_kernel(
    const char* name,
    FuncType /* unboxed_fn: only its type matters */) {
  using Boxed = WrapUnboxedIntoFunctor<FuncType>;
  return executorch::runtime::Kernel(name, Boxed::call);
}
} // namespace extension
} // namespace executorch
// Token-pasting helpers, inspired from C10_CONCATENATE. The extra level of
// indirection makes sure macro arguments (e.g. a unique-id macro) are expanded
// before they are pasted.
#define ET_CONCATENATE_IMPL(s1, s2) s1##s2
#define ET_CONCATENATE(s1, s2) ET_CONCATENATE_IMPL(s1, s2)
#define ET_UID __LINE__
// Unique id used by EXECUTORCH_LIBRARY to name its static variables. __LINE__
// alone collides when two registrations expand on the same source line (for
// example from a single user macro), so prefer __COUNTER__, which yields a new
// value on every expansion, when the compiler provides it. ET_UID itself is
// left unchanged for any other users.
#ifdef __COUNTER__
#define ET_LIBRARY_UID __COUNTER__
#else
#define ET_LIBRARY_UID ET_UID
#endif
// EXECUTORCH_LIBRARY(ns, op_name, func) registers `func` as the boxed kernel
// for "ns::op_name" while initializing a `static` variable. `ns` must be a bare
// identifier (it is both stringized and token-pasted) and `op_name` a string
// literal (it is concatenated with #ns "::"). The unique id is expanded once
// (as a macro argument), so both generated names share the same suffix.
#define EXECUTORCH_LIBRARY(ns, op_name, func) \
  _EXECUTORCH_LIBRARY_IMPL(ns, op_name, func, ET_LIBRARY_UID)
#define _EXECUTORCH_LIBRARY_IMPL(ns, op_name, func, uid)          \
  static constexpr const char ET_CONCATENATE(name_of_op_, uid)[] = \
      #ns "::" op_name;                                             \
  static auto ET_CONCATENATE(res_##ns##_, uid) =                    \
      ::executorch::runtime::register_kernel(                       \
          ::executorch::extension::make_boxed_kernel(               \
              #ns "::" op_name,                                     \
              EXECUTORCH_FN(func, ET_CONCATENATE(name_of_op_, uid))))
// Legacy `torch::executor` spellings of the boxing helpers defined in
// `::executorch::extension`.
// TODO(T197294990): Remove these deprecated aliases once all users have moved
// to the new `::executorch` namespaces.
namespace torch::executor {
using ::executorch::extension::WrapUnboxedIntoFunctor;
using ::executorch::extension::make_boxed_kernel;
} // namespace torch::executor