@@ -72,6 +72,16 @@ class BoxedEvalueList {
7272 */
7373 executorch::aten::ArrayRef<T> get () const ;
7474
75+ /* *
76+ * Result-returning counterpart of get(). Validates each wrapped EValue's
77+ * tag before materializing; returns Error::InvalidType if any element's
78+ * tag does not match T and Error::InvalidState if any element pointer is
79+ * null. Use this when materializing lists from untrusted .pte data so that
80+ * a malformed program cannot force a process abort inside to<T>() /
81+ * ET_CHECK.
82+ */
83+ Result<executorch::aten::ArrayRef<T>> tryGet () const ;
84+
7585 /* *
7686 * Destroys the unwrapped elements without re-dereferencing wrapped_vals_.
7787 * This is safe to call during EValue destruction because it does not
@@ -108,6 +118,10 @@ template <>
108118executorch::aten::ArrayRef<std::optional<executorch::aten::Tensor>>
109119BoxedEvalueList<std::optional<executorch::aten::Tensor>>::get() const ;
110120
121+ template <>
122+ Result<executorch::aten::ArrayRef<std::optional<executorch::aten::Tensor>>>
123+ BoxedEvalueList<std::optional<executorch::aten::Tensor>>::tryGet() const ;
124+
111125// Aggregate typing system similar to IValue only slimmed down with less
112126// functionality, no dependencies on atomic, and fewer supported types to better
113127// suit embedded systems (ie no intrusive ptr)
@@ -405,7 +419,7 @@ struct EValue {
405419 if (payload.copyable_union .as_int_list_ptr == nullptr ) {
406420 return Error::InvalidState;
407421 }
408- return (payload.copyable_union .as_int_list_ptr )->get ();
422+ return (payload.copyable_union .as_int_list_ptr )->tryGet ();
409423 }
410424
411425 /* ***** Bool List Type ******/
@@ -494,7 +508,7 @@ struct EValue {
494508 if (payload.copyable_union .as_tensor_list_ptr == nullptr ) {
495509 return Error::InvalidState;
496510 }
497- return payload.copyable_union .as_tensor_list_ptr ->get ();
511+ return payload.copyable_union .as_tensor_list_ptr ->tryGet ();
498512 }
499513
500514 /* ***** List Optional Tensor Type ******/
@@ -529,7 +543,7 @@ struct EValue {
529543 if (payload.copyable_union .as_list_optional_tensor_ptr == nullptr ) {
530544 return Error::InvalidState;
531545 }
532- return payload.copyable_union .as_list_optional_tensor_ptr ->get ();
546+ return payload.copyable_union .as_list_optional_tensor_ptr ->tryGet ();
533547 }
534548
535549 /* ***** ScalarType Type ******/
@@ -630,7 +644,7 @@ struct EValue {
630644 template <typename T>
631645 inline Result<std::optional<T>> tryToOptional () const {
632646 if (this ->isNone ()) {
633- return std::optional<T>(executorch::aten ::nullopt );
647+ return std::optional<T>(std ::nullopt );
634648 }
635649 auto r = this ->tryTo <T>();
636650 if (!r.ok ()) {
@@ -771,13 +785,39 @@ EVALUE_DEFINE_TRY_TO(executorch::aten::ScalarType, tryToScalarType)
771785EVALUE_DEFINE_TRY_TO (executorch::aten::MemoryFormat, tryToMemoryFormat)
772786EVALUE_DEFINE_TRY_TO (executorch::aten::Layout, tryToLayout)
773787EVALUE_DEFINE_TRY_TO (executorch::aten::Device, tryToDevice)
788+ // Tensor and Optional Tensor
774789EVALUE_DEFINE_TRY_TO (executorch::aten::Tensor, tryToTensor)
790+ EVALUE_DEFINE_TRY_TO (
791+ std::optional<executorch::aten::Tensor>,
792+ tryToOptional<executorch::aten::Tensor>)
793+
794+ // IntList and Optional IntList
775795EVALUE_DEFINE_TRY_TO (executorch::aten::ArrayRef<int64_t >, tryToIntList)
796+ EVALUE_DEFINE_TRY_TO (
797+ std::optional<executorch::aten::ArrayRef<int64_t >>,
798+ tryToOptional<executorch::aten::ArrayRef<int64_t >>)
799+
800+ // DoubleList and Optional DoubleList
776801EVALUE_DEFINE_TRY_TO (executorch::aten::ArrayRef<double >, tryToDoubleList)
802+ EVALUE_DEFINE_TRY_TO (
803+ std::optional<executorch::aten::ArrayRef<double >>,
804+ tryToOptional<executorch::aten::ArrayRef<double >>)
805+
806+ // BoolList and Optional BoolList
777807EVALUE_DEFINE_TRY_TO (executorch::aten::ArrayRef<bool >, tryToBoolList)
808+ EVALUE_DEFINE_TRY_TO (
809+ std::optional<executorch::aten::ArrayRef<bool >>,
810+ tryToOptional<executorch::aten::ArrayRef<bool >>)
811+
812+ // TensorList and Optional TensorList
778813EVALUE_DEFINE_TRY_TO (
779814 executorch::aten::ArrayRef<executorch::aten::Tensor>,
780815 tryToTensorList)
816+ EVALUE_DEFINE_TRY_TO (
817+ std::optional<executorch::aten::ArrayRef<executorch::aten::Tensor>>,
818+ tryToOptional<executorch::aten::ArrayRef<executorch::aten::Tensor>>)
819+
820+ // List of Optional Tensor
781821EVALUE_DEFINE_TRY_TO (
782822 executorch::aten::ArrayRef<std::optional<executorch::aten::Tensor>>,
783823 tryToListOptionalTensor)
@@ -794,6 +834,23 @@ executorch::aten::ArrayRef<T> BoxedEvalueList<T>::get() const {
794834 return executorch::aten::ArrayRef<T>{unwrapped_vals_, wrapped_vals_.size ()};
795835}
796836
837+ template <typename T>
838+ Result<executorch::aten::ArrayRef<T>> BoxedEvalueList<T>::tryGet() const {
839+ for (typename executorch::aten::ArrayRef<T>::size_type i = 0 ;
840+ i < wrapped_vals_.size ();
841+ i++) {
842+ if (wrapped_vals_[i] == nullptr ) {
843+ return Error::InvalidState;
844+ }
845+ auto r = wrapped_vals_[i]->template tryTo <T>();
846+ if (!r.ok ()) {
847+ return r.error ();
848+ }
849+ unwrapped_vals_[i] = std::move (r.get ());
850+ }
851+ return executorch::aten::ArrayRef<T>{unwrapped_vals_, wrapped_vals_.size ()};
852+ }
853+
797854} // namespace runtime
798855} // namespace executorch
799856
0 commit comments