Skip to content

Commit 3bbed22

Browse files
author
Github Executorch
committed
Update
[ghstack-poisoned]
1 parent c86a94e commit 3bbed22

3 files changed

Lines changed: 154 additions & 4 deletions

File tree

runtime/core/evalue.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,5 +27,30 @@ BoxedEvalueList<std::optional<executorch::aten::Tensor>>::get() const {
2727
return executorch::aten::ArrayRef<std::optional<executorch::aten::Tensor>>{
2828
unwrapped_vals_, wrapped_vals_.size()};
2929
}
30+
31+
// Specialization note: unlike the generic tryGet, a null wrapped_vals_[i]
32+
// here is a valid encoding of None (matching the get() specialization above,
33+
// which mirrors parseListOptionalType's "absent index" convention). Only an
34+
// element whose tag is neither None nor Tensor is treated as an error.
35+
template <>
36+
Result<executorch::aten::ArrayRef<std::optional<executorch::aten::Tensor>>>
37+
BoxedEvalueList<std::optional<executorch::aten::Tensor>>::tryGet() const {
38+
for (typename executorch::aten::ArrayRef<
39+
std::optional<executorch::aten::Tensor>>::size_type i = 0;
40+
i < wrapped_vals_.size();
41+
i++) {
42+
if (wrapped_vals_[i] == nullptr) {
43+
unwrapped_vals_[i] = std::nullopt;
44+
continue;
45+
}
46+
auto r = wrapped_vals_[i]->tryToOptional<executorch::aten::Tensor>();
47+
if (!r.ok()) {
48+
return r.error();
49+
}
50+
unwrapped_vals_[i] = std::move(r.get());
51+
}
52+
return executorch::aten::ArrayRef<std::optional<executorch::aten::Tensor>>{
53+
unwrapped_vals_, wrapped_vals_.size()};
54+
}
3055
} // namespace runtime
3156
} // namespace executorch

runtime/core/evalue.h

Lines changed: 61 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,16 @@ class BoxedEvalueList {
7272
*/
7373
executorch::aten::ArrayRef<T> get() const;
7474

75+
/**
76+
* Result-returning counterpart of get(). Validates each wrapped EValue's
77+
* tag before materializing; returns Error::InvalidType if any element's
78+
* tag does not match T and Error::InvalidState if any element pointer is
79+
* null. Use this when materializing lists from untrusted .pte data so that
80+
* a malformed program cannot force a process abort inside to<T>() /
81+
* ET_CHECK.
82+
*/
83+
Result<executorch::aten::ArrayRef<T>> tryGet() const;
84+
7585
/**
7686
* Destroys the unwrapped elements without re-dereferencing wrapped_vals_.
7787
* This is safe to call during EValue destruction because it does not
@@ -108,6 +118,10 @@ template <>
108118
executorch::aten::ArrayRef<std::optional<executorch::aten::Tensor>>
109119
BoxedEvalueList<std::optional<executorch::aten::Tensor>>::get() const;
110120

121+
template <>
122+
Result<executorch::aten::ArrayRef<std::optional<executorch::aten::Tensor>>>
123+
BoxedEvalueList<std::optional<executorch::aten::Tensor>>::tryGet() const;
124+
111125
// Aggregate typing system similar to IValue only slimmed down with less
112126
// functionality, no dependencies on atomic, and fewer supported types to better
113127
// suit embedded systems (ie no intrusive ptr)
@@ -405,7 +419,7 @@ struct EValue {
405419
if (payload.copyable_union.as_int_list_ptr == nullptr) {
406420
return Error::InvalidState;
407421
}
408-
return (payload.copyable_union.as_int_list_ptr)->get();
422+
return (payload.copyable_union.as_int_list_ptr)->tryGet();
409423
}
410424

411425
/****** Bool List Type ******/
@@ -494,7 +508,7 @@ struct EValue {
494508
if (payload.copyable_union.as_tensor_list_ptr == nullptr) {
495509
return Error::InvalidState;
496510
}
497-
return payload.copyable_union.as_tensor_list_ptr->get();
511+
return payload.copyable_union.as_tensor_list_ptr->tryGet();
498512
}
499513

500514
/****** List Optional Tensor Type ******/
@@ -529,7 +543,7 @@ struct EValue {
529543
if (payload.copyable_union.as_list_optional_tensor_ptr == nullptr) {
530544
return Error::InvalidState;
531545
}
532-
return payload.copyable_union.as_list_optional_tensor_ptr->get();
546+
return payload.copyable_union.as_list_optional_tensor_ptr->tryGet();
533547
}
534548

535549
/****** ScalarType Type ******/
@@ -630,7 +644,7 @@ struct EValue {
630644
template <typename T>
631645
inline Result<std::optional<T>> tryToOptional() const {
632646
if (this->isNone()) {
633-
return std::optional<T>(executorch::aten::nullopt);
647+
return std::optional<T>(std::nullopt);
634648
}
635649
auto r = this->tryTo<T>();
636650
if (!r.ok()) {
@@ -771,13 +785,39 @@ EVALUE_DEFINE_TRY_TO(executorch::aten::ScalarType, tryToScalarType)
771785
EVALUE_DEFINE_TRY_TO(executorch::aten::MemoryFormat, tryToMemoryFormat)
772786
EVALUE_DEFINE_TRY_TO(executorch::aten::Layout, tryToLayout)
773787
EVALUE_DEFINE_TRY_TO(executorch::aten::Device, tryToDevice)
788+
// Tensor and Optional Tensor
774789
EVALUE_DEFINE_TRY_TO(executorch::aten::Tensor, tryToTensor)
790+
EVALUE_DEFINE_TRY_TO(
791+
std::optional<executorch::aten::Tensor>,
792+
tryToOptional<executorch::aten::Tensor>)
793+
794+
// IntList and Optional IntList
775795
EVALUE_DEFINE_TRY_TO(executorch::aten::ArrayRef<int64_t>, tryToIntList)
796+
EVALUE_DEFINE_TRY_TO(
797+
std::optional<executorch::aten::ArrayRef<int64_t>>,
798+
tryToOptional<executorch::aten::ArrayRef<int64_t>>)
799+
800+
// DoubleList and Optional DoubleList
776801
EVALUE_DEFINE_TRY_TO(executorch::aten::ArrayRef<double>, tryToDoubleList)
802+
EVALUE_DEFINE_TRY_TO(
803+
std::optional<executorch::aten::ArrayRef<double>>,
804+
tryToOptional<executorch::aten::ArrayRef<double>>)
805+
806+
// BoolList and Optional BoolList
777807
EVALUE_DEFINE_TRY_TO(executorch::aten::ArrayRef<bool>, tryToBoolList)
808+
EVALUE_DEFINE_TRY_TO(
809+
std::optional<executorch::aten::ArrayRef<bool>>,
810+
tryToOptional<executorch::aten::ArrayRef<bool>>)
811+
812+
// TensorList and Optional TensorList
778813
EVALUE_DEFINE_TRY_TO(
779814
executorch::aten::ArrayRef<executorch::aten::Tensor>,
780815
tryToTensorList)
816+
EVALUE_DEFINE_TRY_TO(
817+
std::optional<executorch::aten::ArrayRef<executorch::aten::Tensor>>,
818+
tryToOptional<executorch::aten::ArrayRef<executorch::aten::Tensor>>)
819+
820+
// List of Optional Tensor
781821
EVALUE_DEFINE_TRY_TO(
782822
executorch::aten::ArrayRef<std::optional<executorch::aten::Tensor>>,
783823
tryToListOptionalTensor)
@@ -794,6 +834,23 @@ executorch::aten::ArrayRef<T> BoxedEvalueList<T>::get() const {
794834
return executorch::aten::ArrayRef<T>{unwrapped_vals_, wrapped_vals_.size()};
795835
}
796836

837+
template <typename T>
838+
Result<executorch::aten::ArrayRef<T>> BoxedEvalueList<T>::tryGet() const {
839+
for (typename executorch::aten::ArrayRef<T>::size_type i = 0;
840+
i < wrapped_vals_.size();
841+
i++) {
842+
if (wrapped_vals_[i] == nullptr) {
843+
return Error::InvalidState;
844+
}
845+
auto r = wrapped_vals_[i]->template tryTo<T>();
846+
if (!r.ok()) {
847+
return r.error();
848+
}
849+
unwrapped_vals_[i] = std::move(r.get());
850+
}
851+
return executorch::aten::ArrayRef<T>{unwrapped_vals_, wrapped_vals_.size()};
852+
}
853+
797854
} // namespace runtime
798855
} // namespace executorch
799856

runtime/core/test/evalue_test.cpp

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,57 @@ TEST_F(EValueTest, BoxedEvalueList) {
214214
EXPECT_EQ(unwrapped[2], 3);
215215
}
216216

217+
TEST_F(EValueTest, BoxedEvalueListTryGetSuccess) {
218+
EValue values[3] = {
219+
EValue((int64_t)1), EValue((int64_t)2), EValue((int64_t)3)};
220+
EValue* values_p[3] = {&values[0], &values[1], &values[2]};
221+
int64_t storage[3] = {0, 0, 0};
222+
BoxedEvalueList<int64_t> x{values_p, storage, 3};
223+
auto result = x.tryGet();
224+
EXPECT_TRUE(result.ok());
225+
EXPECT_EQ(result->size(), 3);
226+
EXPECT_EQ((*result)[0], 1);
227+
EXPECT_EQ((*result)[2], 3);
228+
}
229+
230+
TEST_F(EValueTest, BoxedEvalueListTryGetWrongElementTag) {
231+
// Second element is a Double, not an Int; tryGet should reject it rather
232+
// than abort inside to<int64_t>().
233+
EValue values[3] = {EValue((int64_t)1), EValue(3.14), EValue((int64_t)3)};
234+
EValue* values_p[3] = {&values[0], &values[1], &values[2]};
235+
int64_t storage[3] = {0, 0, 0};
236+
BoxedEvalueList<int64_t> x{values_p, storage, 3};
237+
auto result = x.tryGet();
238+
EXPECT_EQ(result.error(), executorch::runtime::Error::InvalidType);
239+
}
240+
241+
TEST_F(EValueTest, BoxedEvalueListTryGetNullElement) {
242+
// A null wrapped pointer is a malformed program for non-optional lists;
243+
// tryGet reports InvalidState rather than aborting inside ET_CHECK.
244+
EValue a((int64_t)1);
245+
EValue c((int64_t)3);
246+
EValue* values_p[3] = {&a, nullptr, &c};
247+
int64_t storage[3] = {0, 0, 0};
248+
BoxedEvalueList<int64_t> x{values_p, storage, 3};
249+
auto result = x.tryGet();
250+
EXPECT_EQ(result.error(), executorch::runtime::Error::InvalidState);
251+
}
252+
253+
TEST_F(EValueTest, BoxedEvalueListTryGetOptionalTensorNullIsNone) {
254+
// For the optional<Tensor> specialization, a null wrapped pointer is a
255+
// valid None encoding (matches parseListOptionalType), not an error.
256+
EValue a;
257+
EValue* values_p[2] = {&a, nullptr};
258+
std::optional<executorch::aten::Tensor> storage[2];
259+
BoxedEvalueList<std::optional<executorch::aten::Tensor>> x{
260+
values_p, storage, 2};
261+
auto result = x.tryGet();
262+
EXPECT_TRUE(result.ok());
263+
EXPECT_EQ(result->size(), 2);
264+
EXPECT_FALSE((*result)[0].has_value());
265+
EXPECT_FALSE((*result)[1].has_value());
266+
}
267+
217268
TEST_F(EValueTest, toOptionalTensorList) {
218269
// create list, empty evalue ctor gets tag::None
219270
EValue values[2] = {EValue(), EValue()};
@@ -602,3 +653,20 @@ TEST_F(EValueTest, TryToOptionalIntTypeMismatch) {
602653
auto result = e.tryToOptional<int64_t>();
603654
EXPECT_EQ(result.error(), executorch::runtime::Error::InvalidType);
604655
}
656+
657+
// Verify tryTo<std::optional<T>>() specializations match tryToOptional<T>()
658+
// semantics, mirroring the to<std::optional<T>>() specializations of to<T>().
659+
TEST_F(EValueTest, TryToTemplateOptionalIntSuccess) {
660+
EValue e(static_cast<int64_t>(42));
661+
auto result = e.tryTo<std::optional<int64_t>>();
662+
EXPECT_TRUE(result.ok());
663+
EXPECT_TRUE(result->has_value());
664+
EXPECT_EQ(result->value(), 42);
665+
}
666+
667+
TEST_F(EValueTest, TryToTemplateOptionalTensorNone) {
668+
EValue e;
669+
auto result = e.tryTo<std::optional<executorch::aten::Tensor>>();
670+
EXPECT_TRUE(result.ok());
671+
EXPECT_FALSE(result->has_value());
672+
}

0 commit comments

Comments
 (0)