1313 Type ,
1414)
1515
16-
1716if TYPE_CHECKING :
1817 from collections .abc import Sequence
1918
@@ -32,20 +31,22 @@ def name(f: NativeFunction) -> str:
3231class Unboxing :
3332 """
3433 Takes a sequence of Bindings and unbox EValues to these Bindings. Return generated code that performs correct unboxing.
35- A sample generated code:
34+ A sample generated code (abbreviated to one arg for readability) :
3635 // aten::mul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
37- void mul_out(Span<EValue*> stack) {
36+ void mul_out(KernelRuntimeContext& context, Span<EValue*> stack) {
3837 EValue& self = *stack[0];
39- EValue& other = *stack[1];
40- EValue& out = *stack[2];
41- const torch::executor::Tensor & self_base = self.to<torch::executor::Tensor>();
42- const torch::executor::Tensor & other_base = other.to<torch::executor::Tensor>();
43- torch::executor::Tensor & out_base = out.to<torch::executor::Tensor>();
44-
38+ // ... other args ...
39+ auto self_base_res = self.tryTo<torch::executor::Tensor>();
40+ if (!self_base_res.ok()) {
41+ ::executorch::runtime::internal::kernel_arg_fail(
42+ context, self_base_res.error(), __func__, "self",
43+ static_cast<uint8_t>(self.tag));
44+ return;
45+ }
46+ const torch::executor::Tensor& self_base = self_base_res.get();
47+ // ... other unpacks ...
4548 EXECUTORCH_SCOPE_PROF("native_call_mul.out");
4649 torch::executor::mul_outf(self_base, other_base, out_base);
47-
48-
4950 }
5051 """
5152
@@ -115,8 +116,19 @@ def argumenttype_evalue_convert(
115116 def _gen_code_base_type (
116117 self , arg_name : str , out_name : str , ctype : CType
117118 ) -> tuple [list [str ], list [str ]]:
119+ # Use tryTo<T>() with a shared cold fail helper so every wrapper
120+ # logs a consistent diagnostic and propagates the error via
121+ # KernelRuntimeContext::fail() rather than aborting.
122+ res_name = f"{ out_name } _res"
118123 return [
119- f"{ ctype .cpp_type ()} { out_name } = { arg_name } .to<{ ctype .cpp_type (strip_ref = True )} >();"
124+ f"auto { res_name } = { arg_name } .tryTo<{ ctype .cpp_type (strip_ref = True )} >();" ,
125+ f"if (!{ res_name } .ok()) {{" ,
126+ " ::executorch::runtime::internal::kernel_arg_fail(" ,
127+ f' context, { res_name } .error(), __func__, "{ arg_name } ",' ,
128+ f" static_cast<uint8_t>({ arg_name } .tag));" ,
129+ " return;" ,
130+ "}" ,
131+ f"{ ctype .cpp_type ()} { out_name } = { res_name } .get();" ,
120132 ], []
121133
122134 def _gen_code_optional_type (
@@ -126,12 +138,20 @@ def _gen_code_optional_type(
126138 res_name , base_type , res_code , decl = self .argumenttype_evalue_convert (
127139 t .elem , in_name
128140 )
141+ # Use tryToOptional<T>() with the shared fail helper (see
142+ # _gen_code_base_type).
143+ opt_res_name = f"{ out_name } _res"
129144 return (
130145 f"""
131- auto { out_name } = { arg_name } .toOptional<{ base_type .cpp_type (strip_ref = True )} >();
132- """ .split (
133- "\n "
134- ),
146+ auto { opt_res_name } = { arg_name } .tryToOptional<{ base_type .cpp_type (strip_ref = True )} >();
147+ if (!{ opt_res_name } .ok()) {{
148+ ::executorch::runtime::internal::kernel_arg_fail(
149+ context, { opt_res_name } .error(), __func__, "{ arg_name } ",
150+ static_cast<uint8_t>({ arg_name } .tag));
151+ return;
152+ }}
153+ auto { out_name } = std::move({ opt_res_name } .get());
154+ """ .split ("\n " ),
135155 decl ,
136156 )
137157
@@ -145,88 +165,94 @@ def _gen_code_list_type(
145165 t .elem , elem_name
146166 )
147167
148- if isinstance (t .elem , BaseType ) and t .elem .name == BaseTy .Tensor :
149- code .extend (
150- f"""
151- auto { out_name } = { arg_name } .toTensorList();
152- """ .split (
153- "\n "
154- )
168+ # Each branch uses the Result-returning tryToXList() accessor and
169+ # routes errors through the shared kernel_arg_fail helper; see
170+ # _gen_code_base_type for the rationale.
171+ res_name_list = f"{ out_name } _res"
172+
173+ def _fail_block (res : str ) -> str :
174+ # Cold fail path: log + context.fail() via the shared helper.
175+ return (
176+ f"if (!{ res } .ok()) {{\n "
177+ f" ::executorch::runtime::internal::kernel_arg_fail(\n "
178+ f' context, { res } .error(), __func__, "{ arg_name } ",\n '
179+ f" static_cast<uint8_t>({ arg_name } .tag));\n "
180+ f" return;\n "
181+ f" }}"
155182 )
183+
184+ if isinstance (t .elem , BaseType ) and t .elem .name == BaseTy .Tensor :
185+ code .extend (f"""
186+ auto { res_name_list } = { arg_name } .tryToTensorList();
187+ { _fail_block (res_name_list )}
188+ auto { out_name } = { res_name_list } .get();
189+ """ .split ("\n " ))
156190 elif isinstance (t .elem , BaseType ) and (
157191 t .elem .name == BaseTy .int or t .elem .name == BaseTy .SymInt
158192 ):
159- code .extend (
160- f"""
161- auto { out_name } = { arg_name } .toIntList();
162- """ .split (
163- "\n "
164- )
165- )
193+ code .extend (f"""
194+ auto { res_name_list } = { arg_name } .tryToIntList();
195+ { _fail_block (res_name_list )}
196+ auto { out_name } = { res_name_list } .get();
197+ """ .split ("\n " ))
166198 elif isinstance (t .elem , BaseType ) and t .elem .name == BaseTy .float :
167- code .extend (
168- f"""
169- auto { out_name } = { arg_name } .toDoubleList();
170- """ .split (
171- "\n "
172- )
173- )
199+ code .extend (f"""
200+ auto { res_name_list } = { arg_name } .tryToDoubleList();
201+ { _fail_block (res_name_list )}
202+ auto { out_name } = { res_name_list } .get();
203+ """ .split ("\n " ))
174204 elif isinstance (t .elem , BaseType ) and t .elem .name == BaseTy .bool :
175205 # handle list type with size, e.g., bool[4]
176- code .extend (
177- f"""
206+ code .extend (f"""
178207#ifdef USE_ATEN_LIB
179208std::array<bool, { t .size } > { out_name } ;
180- auto { in_name } = { arg_name } .toBoolList();
209+ auto { in_name } _res = { arg_name } .tryToBoolList();
210+ { _fail_block (in_name + "_res" )}
211+ auto { in_name } = { in_name } _res.get();
181212size_t _i = 0;
182213for (auto { elem_name } : { in_name } ) {{
183214 { out_name } [_i++] = { elem_name } ;
184215}}
185216#else
186- auto { out_name } = { arg_name } .toBoolList();
217+ auto { res_name_list } = { arg_name } .tryToBoolList();
218+ { _fail_block (res_name_list )}
219+ auto { out_name } = { res_name_list } .get();
187220#endif
188- """ .split (
189- "\n "
190- )
191- )
221+ """ .split ("\n " ))
192222 # pytorch codegen:
193223 # we have to use c10::List for optional element. e.g., Tensor?[] -> c10::List<::std::optional<at::Tensor>>
194224 elif (
195225 isinstance (t .elem , OptionalType )
196226 and isinstance (t .elem .elem , BaseType )
197227 and t .elem .elem .name == BaseTy .Tensor
198228 ):
199- code .extend (
200- f"""
229+ code .extend (f"""
201230#ifdef USE_ATEN_LIB
202- auto { in_name } = { arg_name } .toListOptionalTensor();
231+ auto { in_name } _res = { arg_name } .tryToListOptionalTensor();
232+ { _fail_block (in_name + "_res" )}
233+ auto { in_name } = { in_name } _res.get();
203234c10::List<::std::optional<at::Tensor>> { out_name } ;
204235for (auto { elem_name } : { in_name } ) {{
205236 { out_name } .push_back({ elem_name } );
206237}}
207238#else
208- auto { out_name } = { arg_name } .toListOptionalTensor();
239+ auto { res_name_list } = { arg_name } .tryToListOptionalTensor();
240+ { _fail_block (res_name_list )}
241+ auto { out_name } = { res_name_list } .get();
209242#endif
210- """ .split (
211- "\n "
212- )
213- )
243+ """ .split ("\n " ))
214244 else :
215245 # use ArrayRef as default.
216246 vec_name = arg_name + "_vec"
217247 # need to bring vector instantiation out of scope so that ArrayRef has valid data
218248 decl .append (
219249 f"std::vector<{ res_ctype .cpp_type (strip_ref = True )} > { vec_name } ;"
220250 )
221- code .extend (
222- f"""
251+ code .extend (f"""
223252 for (EValue { elem_name } : { in_name } ) {{
224253 { connector .join (res_code )}
225254 { vec_name } .push_back({ res_name } );
226255 }}
227256 { ctype .cpp_type (strip_ref = True )} { out_name } ({ vec_name } );
228- """ .split (
229- "\n "
230- )
231- )
257+ """ .split ("\n " ))
232258 return code , decl
0 commit comments