1313 Type ,
1414)
1515
16-
1716if TYPE_CHECKING :
1817 from collections .abc import Sequence
1918
@@ -34,13 +33,19 @@ class Unboxing:
3433 Takes a sequence of Bindings and unbox EValues to these Bindings. Return generated code that performs correct unboxing.
3534 A sample generated code:
3635 // aten::mul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
37- void mul_out(Span<EValue*> stack) {
36+ void mul_out(KernelRuntimeContext& context, Span<EValue*> stack) {
3837 EValue& self = *stack[0];
3938 EValue& other = *stack[1];
4039 EValue& out = *stack[2];
41- const torch::executor::Tensor & self_base = self.to<torch::executor::Tensor>();
42- const torch::executor::Tensor & other_base = other.to<torch::executor::Tensor>();
43- torch::executor::Tensor & out_base = out.to<torch::executor::Tensor>();
40+ auto self_base_res = self.tryTo<torch::executor::Tensor>();
41+ if (!self_base_res.ok()) { context.fail(self_base_res.error()); return; }
42+ const torch::executor::Tensor & self_base = self_base_res.get();
43+ auto other_base_res = other.tryTo<torch::executor::Tensor>();
44+ if (!other_base_res.ok()) { context.fail(other_base_res.error()); return; }
45+ const torch::executor::Tensor & other_base = other_base_res.get();
46+ auto out_base_res = out.tryTo<torch::executor::Tensor>();
47+ if (!out_base_res.ok()) { context.fail(out_base_res.error()); return; }
48+ torch::executor::Tensor & out_base = out_base_res.get();
4449
4550 EXECUTORCH_SCOPE_PROF("native_call_mul.out");
4651 torch::executor::mul_outf(self_base, other_base, out_base);
@@ -115,8 +120,15 @@ def argumenttype_evalue_convert(
115120 def _gen_code_base_type (
116121 self , arg_name : str , out_name : str , ctype : CType
117122 ) -> tuple [list [str ], list [str ]]:
123+ # Use the Result-returning tryTo<T>() instead of to<T>() so that a
124+ # malformed PTE with a mismatched EValue tag returns an error to the
125+ # caller via KernelRuntimeContext::fail() rather than aborting the
126+ # process.
127+ res_name = f"{ out_name } _res"
118128 return [
119- f"{ ctype .cpp_type ()} { out_name } = { arg_name } .to<{ ctype .cpp_type (strip_ref = True )} >();"
129+ f"auto { res_name } = { arg_name } .tryTo<{ ctype .cpp_type (strip_ref = True )} >();" ,
130+ f"if (!{ res_name } .ok()) {{ context.fail({ res_name } .error()); return; }}" ,
131+ f"{ ctype .cpp_type ()} { out_name } = { res_name } .get();" ,
120132 ], []
121133
122134 def _gen_code_optional_type (
@@ -126,12 +138,14 @@ def _gen_code_optional_type(
126138 res_name , base_type , res_code , decl = self .argumenttype_evalue_convert (
127139 t .elem , in_name
128140 )
141+ # Use tryToOptional<T>() to propagate tag mismatches as errors.
142+ opt_res_name = f"{ out_name } _res"
129143 return (
130144 f"""
131- auto { out_name } = { arg_name } .toOptional <{ base_type .cpp_type (strip_ref = True )} >();
132- """ . split (
133- " \n "
134- ),
145+ auto { opt_res_name } = { arg_name } .tryToOptional <{ base_type .cpp_type (strip_ref = True )} >();
146+ if (! { opt_res_name } .ok()) {{ context.fail( { opt_res_name } .error()); return; }}
147+ auto { out_name } = std::move( { opt_res_name } .get());
148+ """ . split ( " \n " ),
135149 decl ,
136150 )
137151
@@ -145,88 +159,82 @@ def _gen_code_list_type(
145159 t .elem , elem_name
146160 )
147161
162+ # Each branch uses the Result-returning tryToXList() accessor and
163+ # propagates errors via context.fail(); see _gen_code_base_type for
164+ # the rationale.
165+ res_name_list = f"{ out_name } _res"
148166 if isinstance (t .elem , BaseType ) and t .elem .name == BaseTy .Tensor :
149- code .extend (
150- f"""
151- auto { out_name } = { arg_name } .toTensorList();
152- """ .split (
153- "\n "
154- )
155- )
167+ code .extend (f"""
168+ auto { res_name_list } = { arg_name } .tryToTensorList();
169+ if (!{ res_name_list } .ok()) {{ context.fail({ res_name_list } .error()); return; }}
170+ auto { out_name } = { res_name_list } .get();
171+ """ .split ("\n " ))
156172 elif isinstance (t .elem , BaseType ) and (
157173 t .elem .name == BaseTy .int or t .elem .name == BaseTy .SymInt
158174 ):
159- code .extend (
160- f"""
161- auto { out_name } = { arg_name } .toIntList();
162- """ .split (
163- "\n "
164- )
165- )
175+ code .extend (f"""
176+ auto { res_name_list } = { arg_name } .tryToIntList();
177+ if (!{ res_name_list } .ok()) {{ context.fail({ res_name_list } .error()); return; }}
178+ auto { out_name } = { res_name_list } .get();
179+ """ .split ("\n " ))
166180 elif isinstance (t .elem , BaseType ) and t .elem .name == BaseTy .float :
167- code .extend (
168- f"""
169- auto { out_name } = { arg_name } .toDoubleList();
170- """ .split (
171- "\n "
172- )
173- )
181+ code .extend (f"""
182+ auto { res_name_list } = { arg_name } .tryToDoubleList();
183+ if (!{ res_name_list } .ok()) {{ context.fail({ res_name_list } .error()); return; }}
184+ auto { out_name } = { res_name_list } .get();
185+ """ .split ("\n " ))
174186 elif isinstance (t .elem , BaseType ) and t .elem .name == BaseTy .bool :
175187 # handle list type with size, e.g., bool[4]
176- code .extend (
177- f"""
188+ code .extend (f"""
178189#ifdef USE_ATEN_LIB
179190std::array<bool, { t .size } > { out_name } ;
180- auto { in_name } = { arg_name } .toBoolList();
191+ auto { in_name } _res = { arg_name } .tryToBoolList();
192+ if (!{ in_name } _res.ok()) {{ context.fail({ in_name } _res.error()); return; }}
193+ auto { in_name } = { in_name } _res.get();
181194size_t _i = 0;
182195for (auto { elem_name } : { in_name } ) {{
183196 { out_name } [_i++] = { elem_name } ;
184197}}
185198#else
186- auto { out_name } = { arg_name } .toBoolList();
199+ auto { res_name_list } = { arg_name } .tryToBoolList();
200+ if (!{ res_name_list } .ok()) {{ context.fail({ res_name_list } .error()); return; }}
201+ auto { out_name } = { res_name_list } .get();
187202#endif
188- """ .split (
189- "\n "
190- )
191- )
203+ """ .split ("\n " ))
192204 # pytorch codegen:
193205 # we have to use c10::List for optional element. e.g., Tensor?[] -> c10::List<::std::optional<at::Tensor>>
194206 elif (
195207 isinstance (t .elem , OptionalType )
196208 and isinstance (t .elem .elem , BaseType )
197209 and t .elem .elem .name == BaseTy .Tensor
198210 ):
199- code .extend (
200- f"""
211+ code .extend (f"""
201212#ifdef USE_ATEN_LIB
202- auto { in_name } = { arg_name } .toListOptionalTensor();
213+ auto { in_name } _res = { arg_name } .tryToListOptionalTensor();
214+ if (!{ in_name } _res.ok()) {{ context.fail({ in_name } _res.error()); return; }}
215+ auto { in_name } = { in_name } _res.get();
203216c10::List<::std::optional<at::Tensor>> { out_name } ;
204217for (auto { elem_name } : { in_name } ) {{
205218 { out_name } .push_back({ elem_name } );
206219}}
207220#else
208- auto { out_name } = { arg_name } .toListOptionalTensor();
221+ auto { res_name_list } = { arg_name } .tryToListOptionalTensor();
222+ if (!{ res_name_list } .ok()) {{ context.fail({ res_name_list } .error()); return; }}
223+ auto { out_name } = { res_name_list } .get();
209224#endif
210- """ .split (
211- "\n "
212- )
213- )
225+ """ .split ("\n " ))
214226 else :
215227 # use ArrayRef as default.
216228 vec_name = arg_name + "_vec"
217229 # need to bring vector instantiation out of scope so that ArrayRef has valid data
218230 decl .append (
219231 f"std::vector<{ res_ctype .cpp_type (strip_ref = True )} > { vec_name } ;"
220232 )
221- code .extend (
222- f"""
233+ code .extend (f"""
223234 for (EValue { elem_name } : { in_name } ) {{
224235 { connector .join (res_code )}
225236 { vec_name } .push_back({ res_name } );
226237 }}
227238 { ctype .cpp_type (strip_ref = True )} { out_name } ({ vec_name } );
228- """ .split (
229- "\n "
230- )
231- )
239+ """ .split ("\n " ))
232240 return code , decl
0 commit comments