Skip to content

Commit 6e4eab1

Browse files
author
Github Executorch
committed
Update
[ghstack-poisoned]
1 parent 2d53535 commit 6e4eab1

2 files changed

Lines changed: 422 additions & 2 deletions

File tree

runtime/core/evalue.h

Lines changed: 194 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#pragma once
1010
#include <executorch/runtime/core/exec_aten/exec_aten.h>
11+
#include <executorch/runtime/core/result.h>
1112
#include <executorch/runtime/core/tag.h>
1213
#include <executorch/runtime/platform/assert.h>
1314

@@ -193,6 +194,13 @@ struct EValue {
193194
return payload.copyable_union.as_int;
194195
}
195196

197+
Result<int64_t> tryToInt() const {
198+
if (!isInt()) {
199+
return Error::InvalidType;
200+
}
201+
return payload.copyable_union.as_int;
202+
}
203+
196204
/****** Double Type ******/
197205
/*implicit*/ EValue(double d) : tag(Tag::Double) {
198206
payload.copyable_union.as_double = d;
@@ -207,6 +215,13 @@ struct EValue {
207215
return payload.copyable_union.as_double;
208216
}
209217

218+
Result<double> tryToDouble() const {
219+
if (!isDouble()) {
220+
return Error::InvalidType;
221+
}
222+
return payload.copyable_union.as_double;
223+
}
224+
210225
/****** Bool Type ******/
211226
/*implicit*/ EValue(bool b) : tag(Tag::Bool) {
212227
payload.copyable_union.as_bool = b;
@@ -221,6 +236,13 @@ struct EValue {
221236
return payload.copyable_union.as_bool;
222237
}
223238

239+
Result<bool> tryToBool() const {
240+
if (!isBool()) {
241+
return Error::InvalidType;
242+
}
243+
return payload.copyable_union.as_bool;
244+
}
245+
224246
/****** Scalar Type ******/
225247
/// Construct an EValue using the implicit value of a Scalar.
226248
/*implicit*/ EValue(executorch::aten::Scalar s) {
@@ -256,6 +278,19 @@ struct EValue {
256278
}
257279
}
258280

281+
Result<executorch::aten::Scalar> tryToScalar() const {
282+
if (isDouble()) {
283+
return executorch::aten::Scalar(payload.copyable_union.as_double);
284+
}
285+
if (isInt()) {
286+
return executorch::aten::Scalar(payload.copyable_union.as_int);
287+
}
288+
if (isBool()) {
289+
return executorch::aten::Scalar(payload.copyable_union.as_bool);
290+
}
291+
return Error::InvalidType;
292+
}
293+
259294
/****** Tensor Type ******/
260295
/*implicit*/ EValue(executorch::aten::Tensor t) : tag(Tag::Tensor) {
261296
// When built in aten mode, at::Tensor has a non trivial constructor
@@ -305,6 +340,13 @@ struct EValue {
305340
return payload.as_tensor;
306341
}
307342

343+
Result<executorch::aten::Tensor> tryToTensor() const {
344+
if (!isTensor()) {
345+
return Error::InvalidType;
346+
}
347+
return payload.as_tensor;
348+
}
349+
308350
/****** String Type ******/
309351
/*implicit*/ EValue(executorch::aten::ArrayRef<char>* s) : tag(Tag::String) {
310352
ET_CHECK_MSG(s != nullptr, "ArrayRef<char> pointer cannot be null");
@@ -325,6 +367,18 @@ struct EValue {
325367
payload.copyable_union.as_string_ptr->size());
326368
}
327369

370+
Result<std::string_view> tryToString() const {
371+
if (!isString()) {
372+
return Error::InvalidType;
373+
}
374+
if (payload.copyable_union.as_string_ptr == nullptr) {
375+
return Error::InvalidState;
376+
}
377+
return std::string_view(
378+
payload.copyable_union.as_string_ptr->data(),
379+
payload.copyable_union.as_string_ptr->size());
380+
}
381+
328382
/****** Int List Type ******/
329383
/*implicit*/ EValue(BoxedEvalueList<int64_t>* i) : tag(Tag::ListInt) {
330384
ET_CHECK_MSG(
@@ -344,6 +398,16 @@ struct EValue {
344398
return (payload.copyable_union.as_int_list_ptr)->get();
345399
}
346400

401+
Result<executorch::aten::ArrayRef<int64_t>> tryToIntList() const {
402+
if (!isIntList()) {
403+
return Error::InvalidType;
404+
}
405+
if (payload.copyable_union.as_int_list_ptr == nullptr) {
406+
return Error::InvalidState;
407+
}
408+
return (payload.copyable_union.as_int_list_ptr)->get();
409+
}
410+
347411
/****** Bool List Type ******/
348412
/*implicit*/ EValue(executorch::aten::ArrayRef<bool>* b)
349413
: tag(Tag::ListBool) {
@@ -363,6 +427,16 @@ struct EValue {
363427
return *(payload.copyable_union.as_bool_list_ptr);
364428
}
365429

430+
Result<executorch::aten::ArrayRef<bool>> tryToBoolList() const {
431+
if (!isBoolList()) {
432+
return Error::InvalidType;
433+
}
434+
if (payload.copyable_union.as_bool_list_ptr == nullptr) {
435+
return Error::InvalidState;
436+
}
437+
return *(payload.copyable_union.as_bool_list_ptr);
438+
}
439+
366440
/****** Double List Type ******/
367441
/*implicit*/ EValue(executorch::aten::ArrayRef<double>* d)
368442
: tag(Tag::ListDouble) {
@@ -382,6 +456,16 @@ struct EValue {
382456
return *(payload.copyable_union.as_double_list_ptr);
383457
}
384458

459+
Result<executorch::aten::ArrayRef<double>> tryToDoubleList() const {
460+
if (!isDoubleList()) {
461+
return Error::InvalidType;
462+
}
463+
if (payload.copyable_union.as_double_list_ptr == nullptr) {
464+
return Error::InvalidState;
465+
}
466+
return *(payload.copyable_union.as_double_list_ptr);
467+
}
468+
385469
/****** Tensor List Type ******/
386470
/*implicit*/ EValue(BoxedEvalueList<executorch::aten::Tensor>* t)
387471
: tag(Tag::ListTensor) {
@@ -402,6 +486,17 @@ struct EValue {
402486
return payload.copyable_union.as_tensor_list_ptr->get();
403487
}
404488

489+
Result<executorch::aten::ArrayRef<executorch::aten::Tensor>> tryToTensorList()
490+
const {
491+
if (!isTensorList()) {
492+
return Error::InvalidType;
493+
}
494+
if (payload.copyable_union.as_tensor_list_ptr == nullptr) {
495+
return Error::InvalidState;
496+
}
497+
return payload.copyable_union.as_tensor_list_ptr->get();
498+
}
499+
405500
/****** List Optional Tensor Type ******/
406501
/*implicit*/ EValue(
407502
BoxedEvalueList<std::optional<executorch::aten::Tensor>>* t)
@@ -426,26 +521,60 @@ struct EValue {
426521
return payload.copyable_union.as_list_optional_tensor_ptr->get();
427522
}
428523

524+
Result<executorch::aten::ArrayRef<std::optional<executorch::aten::Tensor>>>
525+
tryToListOptionalTensor() const {
526+
if (!isListOptionalTensor()) {
527+
return Error::InvalidType;
528+
}
529+
if (payload.copyable_union.as_list_optional_tensor_ptr == nullptr) {
530+
return Error::InvalidState;
531+
}
532+
return payload.copyable_union.as_list_optional_tensor_ptr->get();
533+
}
534+
429535
/****** ScalarType Type ******/
430536
executorch::aten::ScalarType toScalarType() const {
431537
ET_CHECK_MSG(isInt(), "EValue is not a ScalarType.");
432538
return static_cast<executorch::aten::ScalarType>(
433539
payload.copyable_union.as_int);
434540
}
435541

542+
Result<executorch::aten::ScalarType> tryToScalarType() const {
543+
if (!isInt()) {
544+
return Error::InvalidType;
545+
}
546+
return static_cast<executorch::aten::ScalarType>(
547+
payload.copyable_union.as_int);
548+
}
549+
436550
/****** MemoryFormat Type ******/
437551
executorch::aten::MemoryFormat toMemoryFormat() const {
438552
ET_CHECK_MSG(isInt(), "EValue is not a MemoryFormat.");
439553
return static_cast<executorch::aten::MemoryFormat>(
440554
payload.copyable_union.as_int);
441555
}
442556

557+
Result<executorch::aten::MemoryFormat> tryToMemoryFormat() const {
558+
if (!isInt()) {
559+
return Error::InvalidType;
560+
}
561+
return static_cast<executorch::aten::MemoryFormat>(
562+
payload.copyable_union.as_int);
563+
}
564+
443565
/****** Layout Type ******/
444566
executorch::aten::Layout toLayout() const {
445567
ET_CHECK_MSG(isInt(), "EValue is not a Layout.");
446568
return static_cast<executorch::aten::Layout>(payload.copyable_union.as_int);
447569
}
448570

571+
Result<executorch::aten::Layout> tryToLayout() const {
572+
if (!isInt()) {
573+
return Error::InvalidType;
574+
}
575+
return static_cast<executorch::aten::Layout>(payload.copyable_union.as_int);
576+
}
577+
449578
/****** Device Type ******/
450579
executorch::aten::Device toDevice() const {
451580
ET_CHECK_MSG(isInt(), "EValue is not a Device.");
@@ -455,13 +584,32 @@ struct EValue {
455584
-1);
456585
}
457586

587+
Result<executorch::aten::Device> tryToDevice() const {
588+
if (!isInt()) {
589+
return Error::InvalidType;
590+
}
591+
return executorch::aten::Device(
592+
static_cast<executorch::aten::DeviceType>(
593+
payload.copyable_union.as_int),
594+
-1);
595+
}
596+
458597
template <typename T>
459598
T to() &&;
460599
template <typename T>
461600
typename internal::evalue_to_const_ref_overload_return<T>::type to() const&;
462601
template <typename T>
463602
typename internal::evalue_to_ref_overload_return<T>::type to() &;
464603

604+
/**
605+
* Result-returning equivalent of `to<T>()`. Returns `Error::InvalidType` on
606+
* tag mismatch instead of aborting, so callers processing untrusted EValues
607+
* (e.g., from a `.pte`) can surface the error rather than terminate.
608+
* Specializations are defined below via `EVALUE_DEFINE_TRY_TO`.
609+
*/
610+
template <typename T>
611+
Result<T> tryTo() const;
612+
465613
/**
466614
* Converts the EValue to an optional object that can represent both T and
467615
* an uninitialized state.
@@ -474,6 +622,23 @@ struct EValue {
474622
return this->to<T>();
475623
}
476624

625+
/**
626+
* Result-returning equivalent of `toOptional<T>()`. None maps to an empty
627+
* optional; any other tag that doesn't match T propagates `tryTo<T>()`'s
628+
* error (`Error::InvalidType`).
629+
*/
630+
template <typename T>
631+
inline Result<std::optional<T>> tryToOptional() const {
632+
if (this->isNone()) {
633+
return std::optional<T>(executorch::aten::nullopt);
634+
}
635+
auto r = this->tryTo<T>();
636+
if (!r.ok()) {
637+
return r.error();
638+
}
639+
return std::optional<T>(std::move(r.get()));
640+
}
641+
477642
private:
478643
// Pre cond: the payload value has had its destructor called
479644
void clearToNone() noexcept {
@@ -524,7 +689,7 @@ struct EValue {
524689

525690
#define EVALUE_DEFINE_TO(T, method_name) \
526691
template <> \
527-
inline T EValue::to<T>()&& { \
692+
inline T EValue::to<T>() && { \
528693
return static_cast<T>(std::move(*this).method_name()); \
529694
} \
530695
template <> \
@@ -538,7 +703,7 @@ struct EValue {
538703
template <> \
539704
inline ::executorch::runtime::internal::evalue_to_ref_overload_return< \
540705
T>::type \
541-
EValue::to<T>()& { \
706+
EValue::to<T>() & { \
542707
typedef ::executorch::runtime::internal::evalue_to_ref_overload_return< \
543708
T>::type return_type; \
544709
return static_cast<return_type>(this->method_name()); \
@@ -591,6 +756,33 @@ EVALUE_DEFINE_TO(
591756
toListOptionalTensor)
592757
#undef EVALUE_DEFINE_TO
593758

759+
#define EVALUE_DEFINE_TRY_TO(T, method_name) \
760+
template <> \
761+
inline Result<T> EValue::tryTo<T>() const { \
762+
return this->method_name(); \
763+
}
764+
765+
EVALUE_DEFINE_TRY_TO(executorch::aten::Scalar, tryToScalar)
766+
EVALUE_DEFINE_TRY_TO(int64_t, tryToInt)
767+
EVALUE_DEFINE_TRY_TO(bool, tryToBool)
768+
EVALUE_DEFINE_TRY_TO(double, tryToDouble)
769+
EVALUE_DEFINE_TRY_TO(std::string_view, tryToString)
770+
EVALUE_DEFINE_TRY_TO(executorch::aten::ScalarType, tryToScalarType)
771+
EVALUE_DEFINE_TRY_TO(executorch::aten::MemoryFormat, tryToMemoryFormat)
772+
EVALUE_DEFINE_TRY_TO(executorch::aten::Layout, tryToLayout)
773+
EVALUE_DEFINE_TRY_TO(executorch::aten::Device, tryToDevice)
774+
EVALUE_DEFINE_TRY_TO(executorch::aten::Tensor, tryToTensor)
775+
EVALUE_DEFINE_TRY_TO(executorch::aten::ArrayRef<int64_t>, tryToIntList)
776+
EVALUE_DEFINE_TRY_TO(executorch::aten::ArrayRef<double>, tryToDoubleList)
777+
EVALUE_DEFINE_TRY_TO(executorch::aten::ArrayRef<bool>, tryToBoolList)
778+
EVALUE_DEFINE_TRY_TO(
779+
executorch::aten::ArrayRef<executorch::aten::Tensor>,
780+
tryToTensorList)
781+
EVALUE_DEFINE_TRY_TO(
782+
executorch::aten::ArrayRef<std::optional<executorch::aten::Tensor>>,
783+
tryToListOptionalTensor)
784+
#undef EVALUE_DEFINE_TRY_TO
785+
594786
template <typename T>
595787
executorch::aten::ArrayRef<T> BoxedEvalueList<T>::get() const {
596788
for (typename executorch::aten::ArrayRef<T>::size_type i = 0;

0 commit comments

Comments
 (0)