Skip to content

Commit 34fe334

Browse files
author
Github Executorch
committed
Codegen: emit EValue::tryTo<T>() in boxed kernel wrappers
Before this change, every generated boxed kernel wrapper unpacked its EValue stack args via the abort-on-mismatch accessors (.to<T>(), .toOptional<T>(), .toTensorList(), .toIntList(), .toDoubleList(), .toBoolList(), .toListOptionalTensor()). A malformed PTE whose KernelCall arg_indices point at wrong-tagged EValues would therefore crash the runtime with ET_CHECK_MSG inside the generated wrapper, before the underlying native kernel sees the inputs.

Emit the Result-returning tryTo* counterparts instead, with a single shared error-propagation pattern:

    auto x_res = ev.tryTo<T>();
    if (!x_res.ok()) {
      context.fail(x_res.error());
      return;
    }
    T x = x_res.get();

In the emitted code the failure step goes through the shared ::executorch::runtime::internal::kernel_arg_fail() helper (added to kernel_runtime_context.h in this commit), which logs the kernel name, argument name, and actual EValue tag and then calls context.fail().

The kernel entry function is already void-returning and takes a KernelRuntimeContext&, which has a documented fail() channel for precisely this case — no ABI change, no signature change, no impact on hand-written kernels registered directly against OpFunction.

Coverage: _gen_code_base_type, _gen_code_optional_type, and each typed branch of _gen_code_list_type (Tensor, int/SymInt, float, bool with ATen/lean variants, and ListOptional<Tensor> with ATen/lean variants).

Depends on the EValue::tryTo<T>() API added earlier in the stack.

Authored-with: Claude
ghstack-source-id: 1b6d5a0
ghstack-comment-id: 4292676210
Pull-Request: #19041
1 parent 0e8c07c commit 34fe334

2 files changed

Lines changed: 113 additions & 59 deletions

File tree

codegen/api/unboxing.py

Lines changed: 85 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
Type,
1414
)
1515

16-
1716
if TYPE_CHECKING:
1817
from collections.abc import Sequence
1918

@@ -32,20 +31,22 @@ def name(f: NativeFunction) -> str:
3231
class Unboxing:
3332
"""
3433
Takes a sequence of Bindings and unboxes EValues to these Bindings. Returns generated code that performs the correct unboxing.
35-
A sample generated code:
34+
A sample generated code (abbreviated to one arg for readability):
3635
// aten::mul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
37-
void mul_out(Span<EValue*> stack) {
36+
void mul_out(KernelRuntimeContext& context, Span<EValue*> stack) {
3837
EValue& self = *stack[0];
39-
EValue& other = *stack[1];
40-
EValue& out = *stack[2];
41-
const torch::executor::Tensor & self_base = self.to<torch::executor::Tensor>();
42-
const torch::executor::Tensor & other_base = other.to<torch::executor::Tensor>();
43-
torch::executor::Tensor & out_base = out.to<torch::executor::Tensor>();
44-
38+
// ... other args ...
39+
auto self_base_res = self.tryTo<torch::executor::Tensor>();
40+
if (!self_base_res.ok()) {
41+
::executorch::runtime::internal::kernel_arg_fail(
42+
context, self_base_res.error(), __func__, "self",
43+
static_cast<uint8_t>(self.tag));
44+
return;
45+
}
46+
const torch::executor::Tensor& self_base = self_base_res.get();
47+
// ... other unpacks ...
4548
EXECUTORCH_SCOPE_PROF("native_call_mul.out");
4649
torch::executor::mul_outf(self_base, other_base, out_base);
47-
48-
4950
}
5051
"""
5152

@@ -115,8 +116,19 @@ def argumenttype_evalue_convert(
115116
def _gen_code_base_type(
116117
self, arg_name: str, out_name: str, ctype: CType
117118
) -> tuple[list[str], list[str]]:
119+
# Use tryTo<T>() with a shared cold fail helper so every wrapper
120+
# logs a consistent diagnostic and propagates the error via
121+
# KernelRuntimeContext::fail() rather than aborting.
122+
res_name = f"{out_name}_res"
118123
return [
119-
f"{ctype.cpp_type()} {out_name} = {arg_name}.to<{ctype.cpp_type(strip_ref=True)}>();"
124+
f"auto {res_name} = {arg_name}.tryTo<{ctype.cpp_type(strip_ref=True)}>();",
125+
f"if (!{res_name}.ok()) {{",
126+
" ::executorch::runtime::internal::kernel_arg_fail(",
127+
f' context, {res_name}.error(), __func__, "{arg_name}",',
128+
f" static_cast<uint8_t>({arg_name}.tag));",
129+
" return;",
130+
"}",
131+
f"{ctype.cpp_type()} {out_name} = {res_name}.get();",
120132
], []
121133

122134
def _gen_code_optional_type(
@@ -126,12 +138,20 @@ def _gen_code_optional_type(
126138
res_name, base_type, res_code, decl = self.argumenttype_evalue_convert(
127139
t.elem, in_name
128140
)
141+
# Use tryToOptional<T>() with the shared fail helper (see
142+
# _gen_code_base_type).
143+
opt_res_name = f"{out_name}_res"
129144
return (
130145
f"""
131-
auto {out_name} = {arg_name}.toOptional<{base_type.cpp_type(strip_ref=True)}>();
132-
""".split(
133-
"\n"
134-
),
146+
auto {opt_res_name} = {arg_name}.tryToOptional<{base_type.cpp_type(strip_ref=True)}>();
147+
if (!{opt_res_name}.ok()) {{
148+
::executorch::runtime::internal::kernel_arg_fail(
149+
context, {opt_res_name}.error(), __func__, "{arg_name}",
150+
static_cast<uint8_t>({arg_name}.tag));
151+
return;
152+
}}
153+
auto {out_name} = std::move({opt_res_name}.get());
154+
""".split("\n"),
135155
decl,
136156
)
137157

@@ -145,88 +165,94 @@ def _gen_code_list_type(
145165
t.elem, elem_name
146166
)
147167

148-
if isinstance(t.elem, BaseType) and t.elem.name == BaseTy.Tensor:
149-
code.extend(
150-
f"""
151-
auto {out_name} = {arg_name}.toTensorList();
152-
""".split(
153-
"\n"
154-
)
168+
# Each branch uses the Result-returning tryToXList() accessor and
169+
# routes errors through the shared kernel_arg_fail helper; see
170+
# _gen_code_base_type for the rationale.
171+
res_name_list = f"{out_name}_res"
172+
173+
def _fail_block(res: str) -> str:
174+
# Cold fail path: log + context.fail() via the shared helper.
175+
return (
176+
f"if (!{res}.ok()) {{\n"
177+
f" ::executorch::runtime::internal::kernel_arg_fail(\n"
178+
f' context, {res}.error(), __func__, "{arg_name}",\n'
179+
f" static_cast<uint8_t>({arg_name}.tag));\n"
180+
f" return;\n"
181+
f" }}"
155182
)
183+
184+
if isinstance(t.elem, BaseType) and t.elem.name == BaseTy.Tensor:
185+
code.extend(f"""
186+
auto {res_name_list} = {arg_name}.tryToTensorList();
187+
{_fail_block(res_name_list)}
188+
auto {out_name} = {res_name_list}.get();
189+
""".split("\n"))
156190
elif isinstance(t.elem, BaseType) and (
157191
t.elem.name == BaseTy.int or t.elem.name == BaseTy.SymInt
158192
):
159-
code.extend(
160-
f"""
161-
auto {out_name} = {arg_name}.toIntList();
162-
""".split(
163-
"\n"
164-
)
165-
)
193+
code.extend(f"""
194+
auto {res_name_list} = {arg_name}.tryToIntList();
195+
{_fail_block(res_name_list)}
196+
auto {out_name} = {res_name_list}.get();
197+
""".split("\n"))
166198
elif isinstance(t.elem, BaseType) and t.elem.name == BaseTy.float:
167-
code.extend(
168-
f"""
169-
auto {out_name} = {arg_name}.toDoubleList();
170-
""".split(
171-
"\n"
172-
)
173-
)
199+
code.extend(f"""
200+
auto {res_name_list} = {arg_name}.tryToDoubleList();
201+
{_fail_block(res_name_list)}
202+
auto {out_name} = {res_name_list}.get();
203+
""".split("\n"))
174204
elif isinstance(t.elem, BaseType) and t.elem.name == BaseTy.bool:
175205
# handle list type with size, e.g., bool[4]
176-
code.extend(
177-
f"""
206+
code.extend(f"""
178207
#ifdef USE_ATEN_LIB
179208
std::array<bool, {t.size}> {out_name};
180-
auto {in_name} = {arg_name}.toBoolList();
209+
auto {in_name}_res = {arg_name}.tryToBoolList();
210+
{_fail_block(in_name + "_res")}
211+
auto {in_name} = {in_name}_res.get();
181212
size_t _i = 0;
182213
for (auto {elem_name}: {in_name}) {{
183214
{out_name}[_i++] = {elem_name};
184215
}}
185216
#else
186-
auto {out_name} = {arg_name}.toBoolList();
217+
auto {res_name_list} = {arg_name}.tryToBoolList();
218+
{_fail_block(res_name_list)}
219+
auto {out_name} = {res_name_list}.get();
187220
#endif
188-
""".split(
189-
"\n"
190-
)
191-
)
221+
""".split("\n"))
192222
# pytorch codegen:
193223
# we have to use c10::List for optional elements, e.g., Tensor?[] -> c10::List<::std::optional<at::Tensor>>
194224
elif (
195225
isinstance(t.elem, OptionalType)
196226
and isinstance(t.elem.elem, BaseType)
197227
and t.elem.elem.name == BaseTy.Tensor
198228
):
199-
code.extend(
200-
f"""
229+
code.extend(f"""
201230
#ifdef USE_ATEN_LIB
202-
auto {in_name} = {arg_name}.toListOptionalTensor();
231+
auto {in_name}_res = {arg_name}.tryToListOptionalTensor();
232+
{_fail_block(in_name + "_res")}
233+
auto {in_name} = {in_name}_res.get();
203234
c10::List<::std::optional<at::Tensor>> {out_name};
204235
for (auto {elem_name}: {in_name}) {{
205236
{out_name}.push_back({elem_name});
206237
}}
207238
#else
208-
auto {out_name} = {arg_name}.toListOptionalTensor();
239+
auto {res_name_list} = {arg_name}.tryToListOptionalTensor();
240+
{_fail_block(res_name_list)}
241+
auto {out_name} = {res_name_list}.get();
209242
#endif
210-
""".split(
211-
"\n"
212-
)
213-
)
243+
""".split("\n"))
214244
else:
215245
# use ArrayRef as default.
216246
vec_name = arg_name + "_vec"
217247
# declare the vector in the enclosing scope (via decl) so that it outlives the loop and the ArrayRef keeps pointing at valid data
218248
decl.append(
219249
f"std::vector<{res_ctype.cpp_type(strip_ref=True)}> {vec_name};"
220250
)
221-
code.extend(
222-
f"""
251+
code.extend(f"""
223252
for (EValue {elem_name}: {in_name}) {{
224253
{connector.join(res_code)}
225254
{vec_name}.push_back({res_name});
226255
}}
227256
{ctype.cpp_type(strip_ref=True)} {out_name}({vec_name});
228-
""".split(
229-
"\n"
230-
)
231-
)
257+
""".split("\n"))
232258
return code, decl

runtime/kernel/kernel_runtime_context.h

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,14 @@
88

99
#pragma once
1010

11+
#include <cstdint>
12+
1113
#include <executorch/runtime/core/error.h>
1214
#include <executorch/runtime/core/event_tracer_hooks.h>
1315
#include <executorch/runtime/core/memory_allocator.h>
1416
#include <executorch/runtime/core/result.h>
1517
#include <executorch/runtime/platform/compiler.h>
18+
#include <executorch/runtime/platform/log.h>
1619

1720
namespace executorch {
1821
namespace ET_RUNTIME_NAMESPACE {
@@ -107,6 +110,31 @@ class KernelRuntimeContext {
107110
Error failure_state_ = Error::Ok;
108111
};
109112

113+
namespace internal {
114+
115+
// Cold path for codegen-emitted boxed kernel wrappers. Logs a diagnostic and
116+
// sets the kernel's failure state when an EValue arg unpack fails. The
117+
// wrapper must still `return` after calling this.
118+
#if defined(__GNUC__) || defined(__clang__)
119+
[[gnu::cold]]
120+
#endif
121+
inline void kernel_arg_fail(
122+
KernelRuntimeContext& context,
123+
Error error,
124+
const char* kernel_name,
125+
const char* arg_name,
126+
uint8_t actual_tag) {
127+
ET_LOG(
128+
Error,
129+
"%s: arg '%s' has unexpected EValue tag %u",
130+
kernel_name,
131+
arg_name,
132+
static_cast<unsigned>(actual_tag));
133+
context.fail(error);
134+
}
135+
136+
} // namespace internal
137+
110138
} // namespace ET_RUNTIME_NAMESPACE
111139
} // namespace executorch
112140

0 commit comments

Comments
 (0)