Skip to content

Commit e6f6215

Browse files
author
Github Executorch
committed
Update
[ghstack-poisoned]
1 parent f478d75 commit e6f6215

1 file changed

Lines changed: 61 additions & 53 deletions

File tree

codegen/api/unboxing.py

Lines changed: 61 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
Type,
1414
)
1515

16-
1716
if TYPE_CHECKING:
1817
from collections.abc import Sequence
1918

@@ -34,13 +33,19 @@ class Unboxing:
3433
Takes a sequence of Bindings and unbox EValues to these Bindings. Return generated code that performs correct unboxing.
3534
A sample generated code:
3635
// aten::mul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
37-
void mul_out(Span<EValue*> stack) {
36+
void mul_out(KernelRuntimeContext& context, Span<EValue*> stack) {
3837
EValue& self = *stack[0];
3938
EValue& other = *stack[1];
4039
EValue& out = *stack[2];
41-
const torch::executor::Tensor & self_base = self.to<torch::executor::Tensor>();
42-
const torch::executor::Tensor & other_base = other.to<torch::executor::Tensor>();
43-
torch::executor::Tensor & out_base = out.to<torch::executor::Tensor>();
40+
auto self_base_res = self.tryTo<torch::executor::Tensor>();
41+
if (!self_base_res.ok()) { context.fail(self_base_res.error()); return; }
42+
const torch::executor::Tensor & self_base = self_base_res.get();
43+
auto other_base_res = other.tryTo<torch::executor::Tensor>();
44+
if (!other_base_res.ok()) { context.fail(other_base_res.error()); return; }
45+
const torch::executor::Tensor & other_base = other_base_res.get();
46+
auto out_base_res = out.tryTo<torch::executor::Tensor>();
47+
if (!out_base_res.ok()) { context.fail(out_base_res.error()); return; }
48+
torch::executor::Tensor & out_base = out_base_res.get();
4449
4550
EXECUTORCH_SCOPE_PROF("native_call_mul.out");
4651
torch::executor::mul_outf(self_base, other_base, out_base);
@@ -115,8 +120,15 @@ def argumenttype_evalue_convert(
115120
def _gen_code_base_type(
116121
self, arg_name: str, out_name: str, ctype: CType
117122
) -> tuple[list[str], list[str]]:
123+
# Use the Result-returning tryTo<T>() instead of to<T>() so that a
124+
# malformed PTE with a mismatched EValue tag returns an error to the
125+
# caller via KernelRuntimeContext::fail() rather than aborting the
126+
# process.
127+
res_name = f"{out_name}_res"
118128
return [
119-
f"{ctype.cpp_type()} {out_name} = {arg_name}.to<{ctype.cpp_type(strip_ref=True)}>();"
129+
f"auto {res_name} = {arg_name}.tryTo<{ctype.cpp_type(strip_ref=True)}>();",
130+
f"if (!{res_name}.ok()) {{ context.fail({res_name}.error()); return; }}",
131+
f"{ctype.cpp_type()} {out_name} = {res_name}.get();",
120132
], []
121133

122134
def _gen_code_optional_type(
@@ -126,12 +138,14 @@ def _gen_code_optional_type(
126138
res_name, base_type, res_code, decl = self.argumenttype_evalue_convert(
127139
t.elem, in_name
128140
)
141+
# Use tryToOptional<T>() to propagate tag mismatches as errors.
142+
opt_res_name = f"{out_name}_res"
129143
return (
130144
f"""
131-
auto {out_name} = {arg_name}.toOptional<{base_type.cpp_type(strip_ref=True)}>();
132-
""".split(
133-
"\n"
134-
),
145+
auto {opt_res_name} = {arg_name}.tryToOptional<{base_type.cpp_type(strip_ref=True)}>();
146+
if (!{opt_res_name}.ok()) {{ context.fail({opt_res_name}.error()); return; }}
147+
auto {out_name} = std::move({opt_res_name}.get());
148+
""".split("\n"),
135149
decl,
136150
)
137151

@@ -145,88 +159,82 @@ def _gen_code_list_type(
145159
t.elem, elem_name
146160
)
147161

162+
# Each branch uses the Result-returning tryToXList() accessor and
163+
# propagates errors via context.fail(); see _gen_code_base_type for
164+
# the rationale.
165+
res_name_list = f"{out_name}_res"
148166
if isinstance(t.elem, BaseType) and t.elem.name == BaseTy.Tensor:
149-
code.extend(
150-
f"""
151-
auto {out_name} = {arg_name}.toTensorList();
152-
""".split(
153-
"\n"
154-
)
155-
)
167+
code.extend(f"""
168+
auto {res_name_list} = {arg_name}.tryToTensorList();
169+
if (!{res_name_list}.ok()) {{ context.fail({res_name_list}.error()); return; }}
170+
auto {out_name} = {res_name_list}.get();
171+
""".split("\n"))
156172
elif isinstance(t.elem, BaseType) and (
157173
t.elem.name == BaseTy.int or t.elem.name == BaseTy.SymInt
158174
):
159-
code.extend(
160-
f"""
161-
auto {out_name} = {arg_name}.toIntList();
162-
""".split(
163-
"\n"
164-
)
165-
)
175+
code.extend(f"""
176+
auto {res_name_list} = {arg_name}.tryToIntList();
177+
if (!{res_name_list}.ok()) {{ context.fail({res_name_list}.error()); return; }}
178+
auto {out_name} = {res_name_list}.get();
179+
""".split("\n"))
166180
elif isinstance(t.elem, BaseType) and t.elem.name == BaseTy.float:
167-
code.extend(
168-
f"""
169-
auto {out_name} = {arg_name}.toDoubleList();
170-
""".split(
171-
"\n"
172-
)
173-
)
181+
code.extend(f"""
182+
auto {res_name_list} = {arg_name}.tryToDoubleList();
183+
if (!{res_name_list}.ok()) {{ context.fail({res_name_list}.error()); return; }}
184+
auto {out_name} = {res_name_list}.get();
185+
""".split("\n"))
174186
elif isinstance(t.elem, BaseType) and t.elem.name == BaseTy.bool:
175187
# handle list type with size, e.g., bool[4]
176-
code.extend(
177-
f"""
188+
code.extend(f"""
178189
#ifdef USE_ATEN_LIB
179190
std::array<bool, {t.size}> {out_name};
180-
auto {in_name} = {arg_name}.toBoolList();
191+
auto {in_name}_res = {arg_name}.tryToBoolList();
192+
if (!{in_name}_res.ok()) {{ context.fail({in_name}_res.error()); return; }}
193+
auto {in_name} = {in_name}_res.get();
181194
size_t _i = 0;
182195
for (auto {elem_name}: {in_name}) {{
183196
{out_name}[_i++] = {elem_name};
184197
}}
185198
#else
186-
auto {out_name} = {arg_name}.toBoolList();
199+
auto {res_name_list} = {arg_name}.tryToBoolList();
200+
if (!{res_name_list}.ok()) {{ context.fail({res_name_list}.error()); return; }}
201+
auto {out_name} = {res_name_list}.get();
187202
#endif
188-
""".split(
189-
"\n"
190-
)
191-
)
203+
""".split("\n"))
192204
# pytorch codegen:
193205
# we have to use c10::List for optional element. e.g., Tensor?[] -> c10::List<::std::optional<at::Tensor>>
194206
elif (
195207
isinstance(t.elem, OptionalType)
196208
and isinstance(t.elem.elem, BaseType)
197209
and t.elem.elem.name == BaseTy.Tensor
198210
):
199-
code.extend(
200-
f"""
211+
code.extend(f"""
201212
#ifdef USE_ATEN_LIB
202-
auto {in_name} = {arg_name}.toListOptionalTensor();
213+
auto {in_name}_res = {arg_name}.tryToListOptionalTensor();
214+
if (!{in_name}_res.ok()) {{ context.fail({in_name}_res.error()); return; }}
215+
auto {in_name} = {in_name}_res.get();
203216
c10::List<::std::optional<at::Tensor>> {out_name};
204217
for (auto {elem_name}: {in_name}) {{
205218
{out_name}.push_back({elem_name});
206219
}}
207220
#else
208-
auto {out_name} = {arg_name}.toListOptionalTensor();
221+
auto {res_name_list} = {arg_name}.tryToListOptionalTensor();
222+
if (!{res_name_list}.ok()) {{ context.fail({res_name_list}.error()); return; }}
223+
auto {out_name} = {res_name_list}.get();
209224
#endif
210-
""".split(
211-
"\n"
212-
)
213-
)
225+
""".split("\n"))
214226
else:
215227
# use ArrayRef as default.
216228
vec_name = arg_name + "_vec"
217229
# need to bring vector instantiation out of scope so that ArrayRef has valid data
218230
decl.append(
219231
f"std::vector<{res_ctype.cpp_type(strip_ref=True)}> {vec_name};"
220232
)
221-
code.extend(
222-
f"""
233+
code.extend(f"""
223234
for (EValue {elem_name}: {in_name}) {{
224235
{connector.join(res_code)}
225236
{vec_name}.push_back({res_name});
226237
}}
227238
{ctype.cpp_type(strip_ref=True)} {out_name}({vec_name});
228-
""".split(
229-
"\n"
230-
)
231-
)
239+
""".split("\n"))
232240
return code, decl

0 commit comments

Comments
 (0)