1313 Type ,
1414)
1515
16-
1716if TYPE_CHECKING :
1817 from collections .abc import Sequence
1918
@@ -34,18 +33,20 @@ class Unboxing:
3433 Takes a sequence of Bindings and unbox EValues to these Bindings. Return generated code that performs correct unboxing.
3534 A sample generated code:
3635 // aten::mul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
37- void mul_out(Span<EValue*> stack) {
36+ void mul_out(KernelRuntimeContext& context, Span<EValue*> stack) {
3837 EValue& self = *stack[0];
39- EValue& other = *stack[1];
40- EValue& out = *stack[2];
41- const torch::executor::Tensor & self_base = self.to<torch::executor::Tensor>();
42- const torch::executor::Tensor & other_base = other.to<torch::executor::Tensor>();
43- torch::executor::Tensor & out_base = out.to<torch::executor::Tensor>();
44-
38+ // ... other args ...
39+ auto self_base_res = self.tryTo<torch::executor::Tensor>();
40+ if (!self_base_res.ok()) {
41+ ::executorch::runtime::internal::kernel_arg_fail(
42+ context, self_base_res.error(), __func__, "self",
43+ static_cast<uint8_t>(self.tag));
44+ return;
45+ }
46+ const torch::executor::Tensor& self_base = self_base_res.get();
47+ // ... other unpacks ...
4548 EXECUTORCH_SCOPE_PROF("native_call_mul.out");
4649 torch::executor::mul_outf(self_base, other_base, out_base);
47-
48-
4950 }
5051 """
5152
@@ -115,8 +116,18 @@ def argumenttype_evalue_convert(
115116 def _gen_code_base_type (
116117 self , arg_name : str , out_name : str , ctype : CType
117118 ) -> tuple [list [str ], list [str ]]:
119+ # Use tryTo<T>(). On error, call the kernel_arg_fail(..), which
120+ # logs the error and sets context.fail(). Then return.
121+ res_name = f"{ out_name } _res"
118122 return [
119- f"{ ctype .cpp_type ()} { out_name } = { arg_name } .to<{ ctype .cpp_type (strip_ref = True )} >();"
123+ f"auto { res_name } = { arg_name } .tryTo<{ ctype .cpp_type (strip_ref = True )} >();" ,
124+ f"if (!{ res_name } .ok()) {{" ,
125+ " ::executorch::runtime::internal::kernel_arg_fail(" ,
126+ f' context, { res_name } .error(), __func__, "{ arg_name } ",' ,
127+ f" static_cast<uint8_t>({ arg_name } .tag));" ,
128+ " return;" ,
129+ "}" ,
130+ f"{ ctype .cpp_type ()} { out_name } = { res_name } .get();" ,
120131 ], []
121132
122133 def _gen_code_optional_type (
@@ -126,12 +137,20 @@ def _gen_code_optional_type(
126137 res_name , base_type , res_code , decl = self .argumenttype_evalue_convert (
127138 t .elem , in_name
128139 )
140+ # Use tryToOptional<T>(). On error, call the kernel_arg_fail(..), which
141+ # logs the error and sets context.fail(). Then return.
142+ opt_res_name = f"{ out_name } _res"
129143 return (
130144 f"""
131- auto { out_name } = { arg_name } .toOptional<{ base_type .cpp_type (strip_ref = True )} >();
132- """ .split (
133- "\n "
134- ),
145+ auto { opt_res_name } = { arg_name } .tryToOptional<{ base_type .cpp_type (strip_ref = True )} >();
146+ if (!{ opt_res_name } .ok()) {{
147+ ::executorch::runtime::internal::kernel_arg_fail(
148+ context, { opt_res_name } .error(), __func__, "{ arg_name } ",
149+ static_cast<uint8_t>({ arg_name } .tag));
150+ return;
151+ }}
152+ auto { out_name } = std::move({ opt_res_name } .get());
153+ """ .split ("\n " ),
135154 decl ,
136155 )
137156
@@ -145,88 +164,92 @@ def _gen_code_list_type(
145164 t .elem , elem_name
146165 )
147166
148- if isinstance (t .elem , BaseType ) and t .elem .name == BaseTy .Tensor :
149- code .extend (
150- f"""
151- auto { out_name } = { arg_name } .toTensorList();
152- """ .split (
153- "\n "
154- )
167+ # Each branch uses the Result-returning tryToXList() accessor and
168+ # routes errors through kernel_arg_fail.
169+ res_name_list = f"{ out_name } _res"
170+
171+ def _fail_block (res : str ) -> str :
172+ return (
173+ f"if (!{ res } .ok()) {{\n "
174+ f" ::executorch::runtime::internal::kernel_arg_fail(\n "
175+ f' context, { res } .error(), __func__, "{ arg_name } ",\n '
176+ f" static_cast<uint8_t>({ arg_name } .tag));\n "
177+ f" return;\n "
178+ f" }}"
155179 )
180+
181+ if isinstance (t .elem , BaseType ) and t .elem .name == BaseTy .Tensor :
182+ code .extend (f"""
183+ auto { res_name_list } = { arg_name } .tryToTensorList();
184+ { _fail_block (res_name_list )}
185+ auto { out_name } = { res_name_list } .get();
186+ """ .split ("\n " ))
156187 elif isinstance (t .elem , BaseType ) and (
157188 t .elem .name == BaseTy .int or t .elem .name == BaseTy .SymInt
158189 ):
159- code .extend (
160- f"""
161- auto { out_name } = { arg_name } .toIntList();
162- """ .split (
163- "\n "
164- )
165- )
190+ code .extend (f"""
191+ auto { res_name_list } = { arg_name } .tryToIntList();
192+ { _fail_block (res_name_list )}
193+ auto { out_name } = { res_name_list } .get();
194+ """ .split ("\n " ))
166195 elif isinstance (t .elem , BaseType ) and t .elem .name == BaseTy .float :
167- code .extend (
168- f"""
169- auto { out_name } = { arg_name } .toDoubleList();
170- """ .split (
171- "\n "
172- )
173- )
196+ code .extend (f"""
197+ auto { res_name_list } = { arg_name } .tryToDoubleList();
198+ { _fail_block (res_name_list )}
199+ auto { out_name } = { res_name_list } .get();
200+ """ .split ("\n " ))
174201 elif isinstance (t .elem , BaseType ) and t .elem .name == BaseTy .bool :
175202 # handle list type with size, e.g., bool[4]
176- code .extend (
177- f"""
203+ code .extend (f"""
178204#ifdef USE_ATEN_LIB
179205std::array<bool, { t .size } > { out_name } ;
180- auto { in_name } = { arg_name } .toBoolList();
206+ auto { in_name } _res = { arg_name } .tryToBoolList();
207+ { _fail_block (in_name + "_res" )}
208+ auto { in_name } = { in_name } _res.get();
181209size_t _i = 0;
182210for (auto { elem_name } : { in_name } ) {{
183211 { out_name } [_i++] = { elem_name } ;
184212}}
185213#else
186- auto { out_name } = { arg_name } .toBoolList();
214+ auto { res_name_list } = { arg_name } .tryToBoolList();
215+ { _fail_block (res_name_list )}
216+ auto { out_name } = { res_name_list } .get();
187217#endif
188- """ .split (
189- "\n "
190- )
191- )
218+ """ .split ("\n " ))
192219 # pytorch codegen:
193220 # we have to use c10::List for optional element. e.g., Tensor?[] -> c10::List<::std::optional<at::Tensor>>
194221 elif (
195222 isinstance (t .elem , OptionalType )
196223 and isinstance (t .elem .elem , BaseType )
197224 and t .elem .elem .name == BaseTy .Tensor
198225 ):
199- code .extend (
200- f"""
226+ code .extend (f"""
201227#ifdef USE_ATEN_LIB
202- auto { in_name } = { arg_name } .toListOptionalTensor();
228+ auto { in_name } _res = { arg_name } .tryToListOptionalTensor();
229+ { _fail_block (in_name + "_res" )}
230+ auto { in_name } = { in_name } _res.get();
203231c10::List<::std::optional<at::Tensor>> { out_name } ;
204232for (auto { elem_name } : { in_name } ) {{
205233 { out_name } .push_back({ elem_name } );
206234}}
207235#else
208- auto { out_name } = { arg_name } .toListOptionalTensor();
236+ auto { res_name_list } = { arg_name } .tryToListOptionalTensor();
237+ { _fail_block (res_name_list )}
238+ auto { out_name } = { res_name_list } .get();
209239#endif
210- """ .split (
211- "\n "
212- )
213- )
240+ """ .split ("\n " ))
214241 else :
215242 # use ArrayRef as default.
216243 vec_name = arg_name + "_vec"
217244 # need to bring vector instantiation out of scope so that ArrayRef has valid data
218245 decl .append (
219246 f"std::vector<{ res_ctype .cpp_type (strip_ref = True )} > { vec_name } ;"
220247 )
221- code .extend (
222- f"""
248+ code .extend (f"""
223249 for (EValue { elem_name } : { in_name } ) {{
224250 { connector .join (res_code )}
225251 { vec_name } .push_back({ res_name } );
226252 }}
227253 { ctype .cpp_type (strip_ref = True )} { out_name } ({ vec_name } );
228- """ .split (
229- "\n "
230- )
231- )
254+ """ .split ("\n " ))
232255 return code , decl
0 commit comments