88
99#pragma once
1010#include < executorch/runtime/core/exec_aten/exec_aten.h>
11+ #include < executorch/runtime/core/result.h>
1112#include < executorch/runtime/core/tag.h>
1213#include < executorch/runtime/platform/assert.h>
1314
@@ -193,6 +194,13 @@ struct EValue {
193194 return payload.copyable_union .as_int ;
194195 }
195196
197+ Result<int64_t > tryToInt () const {
198+ if (!isInt ()) {
199+ return Error::InvalidType;
200+ }
201+ return payload.copyable_union .as_int ;
202+ }
203+
196204 /* ***** Double Type ******/
197205 /* implicit*/ EValue(double d) : tag(Tag::Double) {
198206 payload.copyable_union .as_double = d;
@@ -207,6 +215,13 @@ struct EValue {
207215 return payload.copyable_union .as_double ;
208216 }
209217
218+ Result<double > tryToDouble () const {
219+ if (!isDouble ()) {
220+ return Error::InvalidType;
221+ }
222+ return payload.copyable_union .as_double ;
223+ }
224+
210225 /* ***** Bool Type ******/
211226 /* implicit*/ EValue(bool b) : tag(Tag::Bool) {
212227 payload.copyable_union .as_bool = b;
@@ -221,6 +236,13 @@ struct EValue {
221236 return payload.copyable_union .as_bool ;
222237 }
223238
239+ Result<bool > tryToBool () const {
240+ if (!isBool ()) {
241+ return Error::InvalidType;
242+ }
243+ return payload.copyable_union .as_bool ;
244+ }
245+
224246 /* ***** Scalar Type ******/
225247 // / Construct an EValue using the implicit value of a Scalar.
226248 /* implicit*/ EValue(executorch::aten::Scalar s) {
@@ -256,6 +278,19 @@ struct EValue {
256278 }
257279 }
258280
281+ Result<executorch::aten::Scalar> tryToScalar () const {
282+ if (isDouble ()) {
283+ return executorch::aten::Scalar (payload.copyable_union .as_double );
284+ }
285+ if (isInt ()) {
286+ return executorch::aten::Scalar (payload.copyable_union .as_int );
287+ }
288+ if (isBool ()) {
289+ return executorch::aten::Scalar (payload.copyable_union .as_bool );
290+ }
291+ return Error::InvalidType;
292+ }
293+
259294 /* ***** Tensor Type ******/
260295 /* implicit*/ EValue(executorch::aten::Tensor t) : tag(Tag::Tensor) {
261296 // When built in aten mode, at::Tensor has a non trivial constructor
@@ -305,6 +340,13 @@ struct EValue {
305340 return payload.as_tensor ;
306341 }
307342
343+ Result<executorch::aten::Tensor> tryToTensor () const {
344+ if (!isTensor ()) {
345+ return Error::InvalidType;
346+ }
347+ return payload.as_tensor ;
348+ }
349+
308350 /* ***** String Type ******/
309351 /* implicit*/ EValue(executorch::aten::ArrayRef<char >* s) : tag(Tag::String) {
310352 ET_CHECK_MSG (s != nullptr , " ArrayRef<char> pointer cannot be null" );
@@ -325,6 +367,18 @@ struct EValue {
325367 payload.copyable_union .as_string_ptr ->size ());
326368 }
327369
370+ Result<std::string_view> tryToString () const {
371+ if (!isString ()) {
372+ return Error::InvalidType;
373+ }
374+ if (payload.copyable_union .as_string_ptr == nullptr ) {
375+ return Error::InvalidState;
376+ }
377+ return std::string_view (
378+ payload.copyable_union .as_string_ptr ->data (),
379+ payload.copyable_union .as_string_ptr ->size ());
380+ }
381+
328382 /* ***** Int List Type ******/
329383 /* implicit*/ EValue(BoxedEvalueList<int64_t >* i) : tag(Tag::ListInt) {
330384 ET_CHECK_MSG (
@@ -344,6 +398,16 @@ struct EValue {
344398 return (payload.copyable_union .as_int_list_ptr )->get ();
345399 }
346400
401+ Result<executorch::aten::ArrayRef<int64_t >> tryToIntList () const {
402+ if (!isIntList ()) {
403+ return Error::InvalidType;
404+ }
405+ if (payload.copyable_union .as_int_list_ptr == nullptr ) {
406+ return Error::InvalidState;
407+ }
408+ return (payload.copyable_union .as_int_list_ptr )->get ();
409+ }
410+
347411 /* ***** Bool List Type ******/
348412 /* implicit*/ EValue(executorch::aten::ArrayRef<bool >* b)
349413 : tag(Tag::ListBool) {
@@ -363,6 +427,16 @@ struct EValue {
363427 return *(payload.copyable_union .as_bool_list_ptr );
364428 }
365429
430+ Result<executorch::aten::ArrayRef<bool >> tryToBoolList () const {
431+ if (!isBoolList ()) {
432+ return Error::InvalidType;
433+ }
434+ if (payload.copyable_union .as_bool_list_ptr == nullptr ) {
435+ return Error::InvalidState;
436+ }
437+ return *(payload.copyable_union .as_bool_list_ptr );
438+ }
439+
366440 /* ***** Double List Type ******/
367441 /* implicit*/ EValue(executorch::aten::ArrayRef<double >* d)
368442 : tag(Tag::ListDouble) {
@@ -382,6 +456,16 @@ struct EValue {
382456 return *(payload.copyable_union .as_double_list_ptr );
383457 }
384458
459+ Result<executorch::aten::ArrayRef<double >> tryToDoubleList () const {
460+ if (!isDoubleList ()) {
461+ return Error::InvalidType;
462+ }
463+ if (payload.copyable_union .as_double_list_ptr == nullptr ) {
464+ return Error::InvalidState;
465+ }
466+ return *(payload.copyable_union .as_double_list_ptr );
467+ }
468+
385469 /* ***** Tensor List Type ******/
386470 /* implicit*/ EValue(BoxedEvalueList<executorch::aten::Tensor>* t)
387471 : tag(Tag::ListTensor) {
@@ -402,6 +486,17 @@ struct EValue {
402486 return payload.copyable_union .as_tensor_list_ptr ->get ();
403487 }
404488
489+ Result<executorch::aten::ArrayRef<executorch::aten::Tensor>> tryToTensorList ()
490+ const {
491+ if (!isTensorList ()) {
492+ return Error::InvalidType;
493+ }
494+ if (payload.copyable_union .as_tensor_list_ptr == nullptr ) {
495+ return Error::InvalidState;
496+ }
497+ return payload.copyable_union .as_tensor_list_ptr ->get ();
498+ }
499+
405500 /* ***** List Optional Tensor Type ******/
406501 /* implicit*/ EValue(
407502 BoxedEvalueList<std::optional<executorch::aten::Tensor>>* t)
@@ -426,26 +521,60 @@ struct EValue {
426521 return payload.copyable_union .as_list_optional_tensor_ptr ->get ();
427522 }
428523
524+ Result<executorch::aten::ArrayRef<std::optional<executorch::aten::Tensor>>>
525+ tryToListOptionalTensor () const {
526+ if (!isListOptionalTensor ()) {
527+ return Error::InvalidType;
528+ }
529+ if (payload.copyable_union .as_list_optional_tensor_ptr == nullptr ) {
530+ return Error::InvalidState;
531+ }
532+ return payload.copyable_union .as_list_optional_tensor_ptr ->get ();
533+ }
534+
429535 /* ***** ScalarType Type ******/
430536 executorch::aten::ScalarType toScalarType () const {
431537 ET_CHECK_MSG (isInt (), " EValue is not a ScalarType." );
432538 return static_cast <executorch::aten::ScalarType>(
433539 payload.copyable_union .as_int );
434540 }
435541
542+ Result<executorch::aten::ScalarType> tryToScalarType () const {
543+ if (!isInt ()) {
544+ return Error::InvalidType;
545+ }
546+ return static_cast <executorch::aten::ScalarType>(
547+ payload.copyable_union .as_int );
548+ }
549+
436550 /* ***** MemoryFormat Type ******/
437551 executorch::aten::MemoryFormat toMemoryFormat () const {
438552 ET_CHECK_MSG (isInt (), " EValue is not a MemoryFormat." );
439553 return static_cast <executorch::aten::MemoryFormat>(
440554 payload.copyable_union .as_int );
441555 }
442556
557+ Result<executorch::aten::MemoryFormat> tryToMemoryFormat () const {
558+ if (!isInt ()) {
559+ return Error::InvalidType;
560+ }
561+ return static_cast <executorch::aten::MemoryFormat>(
562+ payload.copyable_union .as_int );
563+ }
564+
443565 /* ***** Layout Type ******/
444566 executorch::aten::Layout toLayout () const {
445567 ET_CHECK_MSG (isInt (), " EValue is not a Layout." );
446568 return static_cast <executorch::aten::Layout>(payload.copyable_union .as_int );
447569 }
448570
571+ Result<executorch::aten::Layout> tryToLayout () const {
572+ if (!isInt ()) {
573+ return Error::InvalidType;
574+ }
575+ return static_cast <executorch::aten::Layout>(payload.copyable_union .as_int );
576+ }
577+
449578 /* ***** Device Type ******/
450579 executorch::aten::Device toDevice () const {
451580 ET_CHECK_MSG (isInt (), " EValue is not a Device." );
@@ -455,13 +584,32 @@ struct EValue {
455584 -1 );
456585 }
457586
587+ Result<executorch::aten::Device> tryToDevice () const {
588+ if (!isInt ()) {
589+ return Error::InvalidType;
590+ }
591+ return executorch::aten::Device (
592+ static_cast <executorch::aten::DeviceType>(
593+ payload.copyable_union .as_int ),
594+ -1 );
595+ }
596+
458597 template <typename T>
459598 T to () &&;
460599 template <typename T>
461600 typename internal::evalue_to_const_ref_overload_return<T>::type to () const &;
462601 template <typename T>
463602 typename internal::evalue_to_ref_overload_return<T>::type to () &;
464603
604+ /* *
605+ * Result-returning equivalent of `to<T>()`. Returns `Error::InvalidType` on
606+ * tag mismatch instead of aborting, so callers processing untrusted EValues
607+ * (e.g., from a `.pte`) can surface the error rather than terminate.
608+ * Specializations are defined below via `EVALUE_DEFINE_TRY_TO`.
609+ */
610+ template <typename T>
611+ Result<T> tryTo () const ;
612+
465613 /* *
466614 * Converts the EValue to an optional object that can represent both T and
467615 * an uninitialized state.
@@ -474,6 +622,23 @@ struct EValue {
474622 return this ->to <T>();
475623 }
476624
625+ /* *
626+ * Result-returning equivalent of `toOptional<T>()`. None maps to an empty
627+ * optional; any other tag that doesn't match T propagates `tryTo<T>()`'s
628+ * error (`Error::InvalidType`).
629+ */
630+ template <typename T>
631+ inline Result<std::optional<T>> tryToOptional () const {
632+ if (this ->isNone ()) {
633+ return std::optional<T>(executorch::aten::nullopt );
634+ }
635+ auto r = this ->tryTo <T>();
636+ if (!r.ok ()) {
637+ return r.error ();
638+ }
639+ return std::optional<T>(std::move (r.get ()));
640+ }
641+
477642 private:
478643 // Pre cond: the payload value has had its destructor called
479644 void clearToNone () noexcept {
@@ -524,7 +689,7 @@ struct EValue {
524689
525690#define EVALUE_DEFINE_TO (T, method_name ) \
526691 template <> \
527- inline T EValue::to<T>()&& { \
692+ inline T EValue::to<T>() && { \
528693 return static_cast <T>(std::move (*this ).method_name ()); \
529694 } \
530695 template <> \
@@ -538,7 +703,7 @@ struct EValue {
538703 template <> \
539704 inline ::executorch::runtime::internal::evalue_to_ref_overload_return< \
540705 T>::type \
541- EValue::to<T>()& { \
706+ EValue::to<T>() & { \
542707 typedef ::executorch::runtime::internal::evalue_to_ref_overload_return< \
543708 T>::type return_type; \
544709 return static_cast <return_type>(this ->method_name ()); \
@@ -591,6 +756,33 @@ EVALUE_DEFINE_TO(
591756 toListOptionalTensor)
592757#undef EVALUE_DEFINE_TO
593758
759+ #define EVALUE_DEFINE_TRY_TO (T, method_name ) \
760+ template <> \
761+ inline Result<T> EValue::tryTo<T>() const { \
762+ return this ->method_name (); \
763+ }
764+
765+ EVALUE_DEFINE_TRY_TO (executorch::aten::Scalar, tryToScalar)
766+ EVALUE_DEFINE_TRY_TO (int64_t , tryToInt)
767+ EVALUE_DEFINE_TRY_TO (bool , tryToBool)
768+ EVALUE_DEFINE_TRY_TO (double , tryToDouble)
769+ EVALUE_DEFINE_TRY_TO (std::string_view, tryToString)
770+ EVALUE_DEFINE_TRY_TO (executorch::aten::ScalarType, tryToScalarType)
771+ EVALUE_DEFINE_TRY_TO (executorch::aten::MemoryFormat, tryToMemoryFormat)
772+ EVALUE_DEFINE_TRY_TO (executorch::aten::Layout, tryToLayout)
773+ EVALUE_DEFINE_TRY_TO (executorch::aten::Device, tryToDevice)
774+ EVALUE_DEFINE_TRY_TO (executorch::aten::Tensor, tryToTensor)
775+ EVALUE_DEFINE_TRY_TO (executorch::aten::ArrayRef<int64_t >, tryToIntList)
776+ EVALUE_DEFINE_TRY_TO (executorch::aten::ArrayRef<double >, tryToDoubleList)
777+ EVALUE_DEFINE_TRY_TO (executorch::aten::ArrayRef<bool >, tryToBoolList)
778+ EVALUE_DEFINE_TRY_TO (
779+ executorch::aten::ArrayRef<executorch::aten::Tensor>,
780+ tryToTensorList)
781+ EVALUE_DEFINE_TRY_TO (
782+ executorch::aten::ArrayRef<std::optional<executorch::aten::Tensor>>,
783+ tryToListOptionalTensor)
784+ #undef EVALUE_DEFINE_TRY_TO
785+
594786template <typename T>
595787executorch::aten::ArrayRef<T> BoxedEvalueList<T>::get() const {
596788 for (typename executorch::aten::ArrayRef<T>::size_type i = 0 ;
0 commit comments