Skip to content

Commit 333b4bb

Browse files
author
Github Executorch
committed
Codegen: emit EValue::tryTo<T>() in boxed kernel wrappers
Before this change, every generated boxed kernel wrapper unpacked its EValue stack args via the abort-on-mismatch accessors (.to<T>(), .toOptional<T>(), .toTensorList(), .toIntList(), .toDoubleList(), .toBoolList(), .toListOptionalTensor()). A malformed PTE whose KernelCall arg_indices point at wrong-tagged EValues would therefore crash the runtime with ET_CHECK_MSG inside the generated wrapper, before the underlying native kernel sees the inputs. Emit the Result-returning tryTo* counterparts instead, with a single shared error-propagation pattern: auto x_res = ev.tryTo<T>(); if (!x_res.ok()) { context.fail(x_res.error()); return; } T x = x_res.get(); The kernel entry function is already void-returning and takes a KernelRuntimeContext&, which has a documented fail() channel for precisely this case — no ABI change, no signature change, no impact on hand-written kernels registered directly against OpFunction. Coverage: _gen_code_base_type, _gen_code_optional_type, and each typed branch of _gen_code_list_type (Tensor, int/SymInt, float, bool with ATen/lean variants, and ListOptional<Tensor> with ATen/lean variants). Depends on the EValue::tryTo<T>() API added earlier in the stack. Authored-with: Claude ghstack-source-id: 4b20a4f ghstack-comment-id: 4292676210 Pull-Request: #19041
1 parent de8ce55 commit 333b4bb

2 files changed

Lines changed: 109 additions & 58 deletions

File tree

codegen/api/unboxing.py

Lines changed: 81 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
Type,
1414
)
1515

16-
1716
if TYPE_CHECKING:
1817
from collections.abc import Sequence
1918

@@ -34,18 +33,20 @@ class Unboxing:
3433
Takes a sequence of Bindings and unbox EValues to these Bindings. Return generated code that performs correct unboxing.
3534
A sample generated code:
3635
// aten::mul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
37-
void mul_out(Span<EValue*> stack) {
36+
void mul_out(KernelRuntimeContext& context, Span<EValue*> stack) {
3837
EValue& self = *stack[0];
39-
EValue& other = *stack[1];
40-
EValue& out = *stack[2];
41-
const torch::executor::Tensor & self_base = self.to<torch::executor::Tensor>();
42-
const torch::executor::Tensor & other_base = other.to<torch::executor::Tensor>();
43-
torch::executor::Tensor & out_base = out.to<torch::executor::Tensor>();
44-
38+
// ... other args ...
39+
auto self_base_res = self.tryTo<torch::executor::Tensor>();
40+
if (!self_base_res.ok()) {
41+
::executorch::runtime::internal::kernel_arg_fail(
42+
context, self_base_res.error(), __func__, "self",
43+
static_cast<uint8_t>(self.tag));
44+
return;
45+
}
46+
const torch::executor::Tensor& self_base = self_base_res.get();
47+
// ... other unpacks ...
4548
EXECUTORCH_SCOPE_PROF("native_call_mul.out");
4649
torch::executor::mul_outf(self_base, other_base, out_base);
47-
48-
4950
}
5051
"""
5152

@@ -115,8 +116,18 @@ def argumenttype_evalue_convert(
115116
def _gen_code_base_type(
116117
self, arg_name: str, out_name: str, ctype: CType
117118
) -> tuple[list[str], list[str]]:
119+
# Use tryTo<T>(). On error, call the kernel_arg_fail(..), which
120+
# logs the error and sets context.fail(). Then return.
121+
res_name = f"{out_name}_res"
118122
return [
119-
f"{ctype.cpp_type()} {out_name} = {arg_name}.to<{ctype.cpp_type(strip_ref=True)}>();"
123+
f"auto {res_name} = {arg_name}.tryTo<{ctype.cpp_type(strip_ref=True)}>();",
124+
f"if (!{res_name}.ok()) {{",
125+
" ::executorch::runtime::internal::kernel_arg_fail(",
126+
f' context, {res_name}.error(), __func__, "{arg_name}",',
127+
f" static_cast<uint8_t>({arg_name}.tag));",
128+
" return;",
129+
"}",
130+
f"{ctype.cpp_type()} {out_name} = {res_name}.get();",
120131
], []
121132

122133
def _gen_code_optional_type(
@@ -126,12 +137,20 @@ def _gen_code_optional_type(
126137
res_name, base_type, res_code, decl = self.argumenttype_evalue_convert(
127138
t.elem, in_name
128139
)
140+
# Use tryToOptional<T>(). On error, call the kernel_arg_fail(..), which
141+
# logs the error and sets context.fail(). Then return.
142+
opt_res_name = f"{out_name}_res"
129143
return (
130144
f"""
131-
auto {out_name} = {arg_name}.toOptional<{base_type.cpp_type(strip_ref=True)}>();
132-
""".split(
133-
"\n"
134-
),
145+
auto {opt_res_name} = {arg_name}.tryToOptional<{base_type.cpp_type(strip_ref=True)}>();
146+
if (!{opt_res_name}.ok()) {{
147+
::executorch::runtime::internal::kernel_arg_fail(
148+
context, {opt_res_name}.error(), __func__, "{arg_name}",
149+
static_cast<uint8_t>({arg_name}.tag));
150+
return;
151+
}}
152+
auto {out_name} = std::move({opt_res_name}.get());
153+
""".split("\n"),
135154
decl,
136155
)
137156

@@ -145,88 +164,92 @@ def _gen_code_list_type(
145164
t.elem, elem_name
146165
)
147166

148-
if isinstance(t.elem, BaseType) and t.elem.name == BaseTy.Tensor:
149-
code.extend(
150-
f"""
151-
auto {out_name} = {arg_name}.toTensorList();
152-
""".split(
153-
"\n"
154-
)
167+
# Each branch uses the Result-returning tryToXList() accessor and
168+
# routes errors through kernel_arg_fail.
169+
res_name_list = f"{out_name}_res"
170+
171+
def _fail_block(res: str) -> str:
172+
return (
173+
f"if (!{res}.ok()) {{\n"
174+
f" ::executorch::runtime::internal::kernel_arg_fail(\n"
175+
f' context, {res}.error(), __func__, "{arg_name}",\n'
176+
f" static_cast<uint8_t>({arg_name}.tag));\n"
177+
f" return;\n"
178+
f" }}"
155179
)
180+
181+
if isinstance(t.elem, BaseType) and t.elem.name == BaseTy.Tensor:
182+
code.extend(f"""
183+
auto {res_name_list} = {arg_name}.tryToTensorList();
184+
{_fail_block(res_name_list)}
185+
auto {out_name} = {res_name_list}.get();
186+
""".split("\n"))
156187
elif isinstance(t.elem, BaseType) and (
157188
t.elem.name == BaseTy.int or t.elem.name == BaseTy.SymInt
158189
):
159-
code.extend(
160-
f"""
161-
auto {out_name} = {arg_name}.toIntList();
162-
""".split(
163-
"\n"
164-
)
165-
)
190+
code.extend(f"""
191+
auto {res_name_list} = {arg_name}.tryToIntList();
192+
{_fail_block(res_name_list)}
193+
auto {out_name} = {res_name_list}.get();
194+
""".split("\n"))
166195
elif isinstance(t.elem, BaseType) and t.elem.name == BaseTy.float:
167-
code.extend(
168-
f"""
169-
auto {out_name} = {arg_name}.toDoubleList();
170-
""".split(
171-
"\n"
172-
)
173-
)
196+
code.extend(f"""
197+
auto {res_name_list} = {arg_name}.tryToDoubleList();
198+
{_fail_block(res_name_list)}
199+
auto {out_name} = {res_name_list}.get();
200+
""".split("\n"))
174201
elif isinstance(t.elem, BaseType) and t.elem.name == BaseTy.bool:
175202
# handle list type with size, e.g., bool[4]
176-
code.extend(
177-
f"""
203+
code.extend(f"""
178204
#ifdef USE_ATEN_LIB
179205
std::array<bool, {t.size}> {out_name};
180-
auto {in_name} = {arg_name}.toBoolList();
206+
auto {in_name}_res = {arg_name}.tryToBoolList();
207+
{_fail_block(in_name + "_res")}
208+
auto {in_name} = {in_name}_res.get();
181209
size_t _i = 0;
182210
for (auto {elem_name}: {in_name}) {{
183211
{out_name}[_i++] = {elem_name};
184212
}}
185213
#else
186-
auto {out_name} = {arg_name}.toBoolList();
214+
auto {res_name_list} = {arg_name}.tryToBoolList();
215+
{_fail_block(res_name_list)}
216+
auto {out_name} = {res_name_list}.get();
187217
#endif
188-
""".split(
189-
"\n"
190-
)
191-
)
218+
""".split("\n"))
192219
# pytorch codegen:
193220
# we have to use c10::List for optional element. e.g., Tensor?[] -> c10::List<::std::optional<at::Tensor>>
194221
elif (
195222
isinstance(t.elem, OptionalType)
196223
and isinstance(t.elem.elem, BaseType)
197224
and t.elem.elem.name == BaseTy.Tensor
198225
):
199-
code.extend(
200-
f"""
226+
code.extend(f"""
201227
#ifdef USE_ATEN_LIB
202-
auto {in_name} = {arg_name}.toListOptionalTensor();
228+
auto {in_name}_res = {arg_name}.tryToListOptionalTensor();
229+
{_fail_block(in_name + "_res")}
230+
auto {in_name} = {in_name}_res.get();
203231
c10::List<::std::optional<at::Tensor>> {out_name};
204232
for (auto {elem_name}: {in_name}) {{
205233
{out_name}.push_back({elem_name});
206234
}}
207235
#else
208-
auto {out_name} = {arg_name}.toListOptionalTensor();
236+
auto {res_name_list} = {arg_name}.tryToListOptionalTensor();
237+
{_fail_block(res_name_list)}
238+
auto {out_name} = {res_name_list}.get();
209239
#endif
210-
""".split(
211-
"\n"
212-
)
213-
)
240+
""".split("\n"))
214241
else:
215242
# use ArrayRef as default.
216243
vec_name = arg_name + "_vec"
217244
# need to bring vector instantiation out of scope so that ArrayRef has valid data
218245
decl.append(
219246
f"std::vector<{res_ctype.cpp_type(strip_ref=True)}> {vec_name};"
220247
)
221-
code.extend(
222-
f"""
248+
code.extend(f"""
223249
for (EValue {elem_name}: {in_name}) {{
224250
{connector.join(res_code)}
225251
{vec_name}.push_back({res_name});
226252
}}
227253
{ctype.cpp_type(strip_ref=True)} {out_name}({vec_name});
228-
""".split(
229-
"\n"
230-
)
231-
)
254+
""".split("\n"))
232255
return code, decl

runtime/kernel/kernel_runtime_context.h

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,14 @@
88

99
#pragma once
1010

11+
#include <cstdint>
12+
1113
#include <executorch/runtime/core/error.h>
1214
#include <executorch/runtime/core/event_tracer_hooks.h>
1315
#include <executorch/runtime/core/memory_allocator.h>
1416
#include <executorch/runtime/core/result.h>
1517
#include <executorch/runtime/platform/compiler.h>
18+
#include <executorch/runtime/platform/log.h>
1619

1720
namespace executorch {
1821
namespace ET_RUNTIME_NAMESPACE {
@@ -107,6 +110,31 @@ class KernelRuntimeContext {
107110
Error failure_state_ = Error::Ok;
108111
};
109112

113+
namespace internal {
114+
115+
// Cold path for codegen-emitted boxed kernel wrappers. Logs the error and
116+
// sets the kernel's failure state when an EValue arg unpack fails. The
117+
// wrapper must still `return` after calling this.
118+
#if defined(__GNUC__) || defined(__clang__)
119+
[[gnu::cold]]
120+
#endif
121+
inline void kernel_arg_fail(
122+
KernelRuntimeContext& context,
123+
Error error,
124+
const char* kernel_name,
125+
const char* arg_name,
126+
uint8_t actual_tag) {
127+
ET_LOG(
128+
Error,
129+
"%s: arg '%s' has unexpected EValue tag %u",
130+
kernel_name,
131+
arg_name,
132+
static_cast<unsigned>(actual_tag));
133+
context.fail(error);
134+
}
135+
136+
} // namespace internal
137+
110138
} // namespace ET_RUNTIME_NAMESPACE
111139
} // namespace executorch
112140

0 commit comments

Comments
 (0)