-
Notifications
You must be signed in to change notification settings - Fork 1k
Add EValue::tryTo<T>() for all EValue payload types #19036
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -1,3 +1,3 @@ | ||||||||||||||||||||||||||||||||||||||||||
| /* | ||||||||||||||||||||||||||||||||||||||||||
| * Copyright (c) Meta Platforms, Inc. and affiliates. | ||||||||||||||||||||||||||||||||||||||||||
| * All rights reserved. | ||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -8,6 +8,7 @@ | |||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| #pragma once | ||||||||||||||||||||||||||||||||||||||||||
| #include <executorch/runtime/core/exec_aten/exec_aten.h> | ||||||||||||||||||||||||||||||||||||||||||
| #include <executorch/runtime/core/result.h> | ||||||||||||||||||||||||||||||||||||||||||
| #include <executorch/runtime/core/tag.h> | ||||||||||||||||||||||||||||||||||||||||||
| #include <executorch/runtime/platform/assert.h> | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -193,6 +194,13 @@ | |||||||||||||||||||||||||||||||||||||||||
| return payload.copyable_union.as_int; | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| Result<int64_t> tryToInt() const { | ||||||||||||||||||||||||||||||||||||||||||
| if (!isInt()) { | ||||||||||||||||||||||||||||||||||||||||||
| return Error::InvalidType; | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
| return payload.copyable_union.as_int; | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| /****** Double Type ******/ | ||||||||||||||||||||||||||||||||||||||||||
| /*implicit*/ EValue(double d) : tag(Tag::Double) { | ||||||||||||||||||||||||||||||||||||||||||
| payload.copyable_union.as_double = d; | ||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -207,6 +215,13 @@ | |||||||||||||||||||||||||||||||||||||||||
| return payload.copyable_union.as_double; | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| Result<double> tryToDouble() const { | ||||||||||||||||||||||||||||||||||||||||||
| if (!isDouble()) { | ||||||||||||||||||||||||||||||||||||||||||
| return Error::InvalidType; | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
| return payload.copyable_union.as_double; | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| /****** Bool Type ******/ | ||||||||||||||||||||||||||||||||||||||||||
| /*implicit*/ EValue(bool b) : tag(Tag::Bool) { | ||||||||||||||||||||||||||||||||||||||||||
| payload.copyable_union.as_bool = b; | ||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -221,6 +236,13 @@ | |||||||||||||||||||||||||||||||||||||||||
| return payload.copyable_union.as_bool; | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| Result<bool> tryToBool() const { | ||||||||||||||||||||||||||||||||||||||||||
| if (!isBool()) { | ||||||||||||||||||||||||||||||||||||||||||
| return Error::InvalidType; | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
| return payload.copyable_union.as_bool; | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| /****** Scalar Type ******/ | ||||||||||||||||||||||||||||||||||||||||||
| /// Construct an EValue using the implicit value of a Scalar. | ||||||||||||||||||||||||||||||||||||||||||
| /*implicit*/ EValue(executorch::aten::Scalar s) { | ||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -256,6 +278,19 @@ | |||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| Result<executorch::aten::Scalar> tryToScalar() const { | ||||||||||||||||||||||||||||||||||||||||||
| if (isDouble()) { | ||||||||||||||||||||||||||||||||||||||||||
| return executorch::aten::Scalar(payload.copyable_union.as_double); | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
| if (isInt()) { | ||||||||||||||||||||||||||||||||||||||||||
| return executorch::aten::Scalar(payload.copyable_union.as_int); | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
| if (isBool()) { | ||||||||||||||||||||||||||||||||||||||||||
| return executorch::aten::Scalar(payload.copyable_union.as_bool); | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
| return Error::InvalidType; | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| /****** Tensor Type ******/ | ||||||||||||||||||||||||||||||||||||||||||
| /*implicit*/ EValue(executorch::aten::Tensor t) : tag(Tag::Tensor) { | ||||||||||||||||||||||||||||||||||||||||||
| // When built in aten mode, at::Tensor has a non trivial constructor | ||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -305,6 +340,13 @@ | |||||||||||||||||||||||||||||||||||||||||
| return payload.as_tensor; | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| Result<executorch::aten::Tensor> tryToTensor() const { | ||||||||||||||||||||||||||||||||||||||||||
| if (!isTensor()) { | ||||||||||||||||||||||||||||||||||||||||||
| return Error::InvalidType; | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
| return payload.as_tensor; | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| /****** String Type ******/ | ||||||||||||||||||||||||||||||||||||||||||
| /*implicit*/ EValue(executorch::aten::ArrayRef<char>* s) : tag(Tag::String) { | ||||||||||||||||||||||||||||||||||||||||||
| ET_CHECK_MSG(s != nullptr, "ArrayRef<char> pointer cannot be null"); | ||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -325,6 +367,18 @@ | |||||||||||||||||||||||||||||||||||||||||
| payload.copyable_union.as_string_ptr->size()); | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| Result<std::string_view> tryToString() const { | ||||||||||||||||||||||||||||||||||||||||||
| if (!isString()) { | ||||||||||||||||||||||||||||||||||||||||||
| return Error::InvalidType; | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
| if (payload.copyable_union.as_string_ptr == nullptr) { | ||||||||||||||||||||||||||||||||||||||||||
| return Error::InvalidState; | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
| return std::string_view( | ||||||||||||||||||||||||||||||||||||||||||
| payload.copyable_union.as_string_ptr->data(), | ||||||||||||||||||||||||||||||||||||||||||
| payload.copyable_union.as_string_ptr->size()); | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| /****** Int List Type ******/ | ||||||||||||||||||||||||||||||||||||||||||
| /*implicit*/ EValue(BoxedEvalueList<int64_t>* i) : tag(Tag::ListInt) { | ||||||||||||||||||||||||||||||||||||||||||
| ET_CHECK_MSG( | ||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -344,6 +398,16 @@ | |||||||||||||||||||||||||||||||||||||||||
| return (payload.copyable_union.as_int_list_ptr)->get(); | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| Result<executorch::aten::ArrayRef<int64_t>> tryToIntList() const { | ||||||||||||||||||||||||||||||||||||||||||
| if (!isIntList()) { | ||||||||||||||||||||||||||||||||||||||||||
| return Error::InvalidType; | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
| if (payload.copyable_union.as_int_list_ptr == nullptr) { | ||||||||||||||||||||||||||||||||||||||||||
| return Error::InvalidState; | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
| return (payload.copyable_union.as_int_list_ptr)->get(); | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| /****** Bool List Type ******/ | ||||||||||||||||||||||||||||||||||||||||||
| /*implicit*/ EValue(executorch::aten::ArrayRef<bool>* b) | ||||||||||||||||||||||||||||||||||||||||||
| : tag(Tag::ListBool) { | ||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -363,6 +427,16 @@ | |||||||||||||||||||||||||||||||||||||||||
| return *(payload.copyable_union.as_bool_list_ptr); | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| Result<executorch::aten::ArrayRef<bool>> tryToBoolList() const { | ||||||||||||||||||||||||||||||||||||||||||
| if (!isBoolList()) { | ||||||||||||||||||||||||||||||||||||||||||
| return Error::InvalidType; | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
| if (payload.copyable_union.as_bool_list_ptr == nullptr) { | ||||||||||||||||||||||||||||||||||||||||||
| return Error::InvalidState; | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
| return *(payload.copyable_union.as_bool_list_ptr); | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| /****** Double List Type ******/ | ||||||||||||||||||||||||||||||||||||||||||
| /*implicit*/ EValue(executorch::aten::ArrayRef<double>* d) | ||||||||||||||||||||||||||||||||||||||||||
| : tag(Tag::ListDouble) { | ||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -382,6 +456,16 @@ | |||||||||||||||||||||||||||||||||||||||||
| return *(payload.copyable_union.as_double_list_ptr); | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| Result<executorch::aten::ArrayRef<double>> tryToDoubleList() const { | ||||||||||||||||||||||||||||||||||||||||||
| if (!isDoubleList()) { | ||||||||||||||||||||||||||||||||||||||||||
| return Error::InvalidType; | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
| if (payload.copyable_union.as_double_list_ptr == nullptr) { | ||||||||||||||||||||||||||||||||||||||||||
| return Error::InvalidState; | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
| return *(payload.copyable_union.as_double_list_ptr); | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| /****** Tensor List Type ******/ | ||||||||||||||||||||||||||||||||||||||||||
| /*implicit*/ EValue(BoxedEvalueList<executorch::aten::Tensor>* t) | ||||||||||||||||||||||||||||||||||||||||||
| : tag(Tag::ListTensor) { | ||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -402,6 +486,17 @@ | |||||||||||||||||||||||||||||||||||||||||
| return payload.copyable_union.as_tensor_list_ptr->get(); | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| Result<executorch::aten::ArrayRef<executorch::aten::Tensor>> tryToTensorList() | ||||||||||||||||||||||||||||||||||||||||||
| const { | ||||||||||||||||||||||||||||||||||||||||||
| if (!isTensorList()) { | ||||||||||||||||||||||||||||||||||||||||||
| return Error::InvalidType; | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
| if (payload.copyable_union.as_tensor_list_ptr == nullptr) { | ||||||||||||||||||||||||||||||||||||||||||
| return Error::InvalidState; | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
| return payload.copyable_union.as_tensor_list_ptr->get(); | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| /****** List Optional Tensor Type ******/ | ||||||||||||||||||||||||||||||||||||||||||
| /*implicit*/ EValue( | ||||||||||||||||||||||||||||||||||||||||||
| BoxedEvalueList<std::optional<executorch::aten::Tensor>>* t) | ||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -426,26 +521,60 @@ | |||||||||||||||||||||||||||||||||||||||||
| return payload.copyable_union.as_list_optional_tensor_ptr->get(); | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| Result<executorch::aten::ArrayRef<std::optional<executorch::aten::Tensor>>> | ||||||||||||||||||||||||||||||||||||||||||
| tryToListOptionalTensor() const { | ||||||||||||||||||||||||||||||||||||||||||
| if (!isListOptionalTensor()) { | ||||||||||||||||||||||||||||||||||||||||||
| return Error::InvalidType; | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
| if (payload.copyable_union.as_list_optional_tensor_ptr == nullptr) { | ||||||||||||||||||||||||||||||||||||||||||
| return Error::InvalidState; | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
| return payload.copyable_union.as_list_optional_tensor_ptr->get(); | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| /****** ScalarType Type ******/ | ||||||||||||||||||||||||||||||||||||||||||
| executorch::aten::ScalarType toScalarType() const { | ||||||||||||||||||||||||||||||||||||||||||
| ET_CHECK_MSG(isInt(), "EValue is not a ScalarType."); | ||||||||||||||||||||||||||||||||||||||||||
| return static_cast<executorch::aten::ScalarType>( | ||||||||||||||||||||||||||||||||||||||||||
| payload.copyable_union.as_int); | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| Result<executorch::aten::ScalarType> tryToScalarType() const { | ||||||||||||||||||||||||||||||||||||||||||
| if (!isInt()) { | ||||||||||||||||||||||||||||||||||||||||||
| return Error::InvalidType; | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
| return static_cast<executorch::aten::ScalarType>( | ||||||||||||||||||||||||||||||||||||||||||
| payload.copyable_union.as_int); | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| /****** MemoryFormat Type ******/ | ||||||||||||||||||||||||||||||||||||||||||
| executorch::aten::MemoryFormat toMemoryFormat() const { | ||||||||||||||||||||||||||||||||||||||||||
| ET_CHECK_MSG(isInt(), "EValue is not a MemoryFormat."); | ||||||||||||||||||||||||||||||||||||||||||
| return static_cast<executorch::aten::MemoryFormat>( | ||||||||||||||||||||||||||||||||||||||||||
| payload.copyable_union.as_int); | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| Result<executorch::aten::MemoryFormat> tryToMemoryFormat() const { | ||||||||||||||||||||||||||||||||||||||||||
| if (!isInt()) { | ||||||||||||||||||||||||||||||||||||||||||
| return Error::InvalidType; | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
| return static_cast<executorch::aten::MemoryFormat>( | ||||||||||||||||||||||||||||||||||||||||||
| payload.copyable_union.as_int); | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| /****** Layout Type ******/ | ||||||||||||||||||||||||||||||||||||||||||
| executorch::aten::Layout toLayout() const { | ||||||||||||||||||||||||||||||||||||||||||
| ET_CHECK_MSG(isInt(), "EValue is not a Layout."); | ||||||||||||||||||||||||||||||||||||||||||
| return static_cast<executorch::aten::Layout>(payload.copyable_union.as_int); | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| Result<executorch::aten::Layout> tryToLayout() const { | ||||||||||||||||||||||||||||||||||||||||||
| if (!isInt()) { | ||||||||||||||||||||||||||||||||||||||||||
| return Error::InvalidType; | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
| return static_cast<executorch::aten::Layout>(payload.copyable_union.as_int); | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| /****** Device Type ******/ | ||||||||||||||||||||||||||||||||||||||||||
| executorch::aten::Device toDevice() const { | ||||||||||||||||||||||||||||||||||||||||||
| ET_CHECK_MSG(isInt(), "EValue is not a Device."); | ||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -455,13 +584,32 @@ | |||||||||||||||||||||||||||||||||||||||||
| -1); | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| Result<executorch::aten::Device> tryToDevice() const { | ||||||||||||||||||||||||||||||||||||||||||
| if (!isInt()) { | ||||||||||||||||||||||||||||||||||||||||||
| return Error::InvalidType; | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
| return executorch::aten::Device( | ||||||||||||||||||||||||||||||||||||||||||
| static_cast<executorch::aten::DeviceType>( | ||||||||||||||||||||||||||||||||||||||||||
| payload.copyable_union.as_int), | ||||||||||||||||||||||||||||||||||||||||||
| -1); | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| template <typename T> | ||||||||||||||||||||||||||||||||||||||||||
| T to() &&; | ||||||||||||||||||||||||||||||||||||||||||
| template <typename T> | ||||||||||||||||||||||||||||||||||||||||||
| typename internal::evalue_to_const_ref_overload_return<T>::type to() const&; | ||||||||||||||||||||||||||||||||||||||||||
| template <typename T> | ||||||||||||||||||||||||||||||||||||||||||
| typename internal::evalue_to_ref_overload_return<T>::type to() &; | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| /** | ||||||||||||||||||||||||||||||||||||||||||
| * Result-returning equivalent of `to<T>()`. Returns `Error::InvalidType` on | ||||||||||||||||||||||||||||||||||||||||||
| * tag mismatch instead of aborting, so callers processing untrusted EValues | ||||||||||||||||||||||||||||||||||||||||||
| * (e.g., from a `.pte`) can surface the error rather than terminate. | ||||||||||||||||||||||||||||||||||||||||||
| * Specializations are defined below via `EVALUE_DEFINE_TRY_TO`. | ||||||||||||||||||||||||||||||||||||||||||
| */ | ||||||||||||||||||||||||||||||||||||||||||
| template <typename T> | ||||||||||||||||||||||||||||||||||||||||||
| Result<T> tryTo() const; | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| /** | ||||||||||||||||||||||||||||||||||||||||||
| * Converts the EValue to an optional object that can represent both T and | ||||||||||||||||||||||||||||||||||||||||||
| * an uninitialized state. | ||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -474,6 +622,23 @@ | |||||||||||||||||||||||||||||||||||||||||
| return this->to<T>(); | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| /** | ||||||||||||||||||||||||||||||||||||||||||
| * Result-returning equivalent of `toOptional<T>()`. None maps to an empty | ||||||||||||||||||||||||||||||||||||||||||
| * optional; any other tag that doesn't match T propagates `tryTo<T>()`'s | ||||||||||||||||||||||||||||||||||||||||||
| * error (`Error::InvalidType`). | ||||||||||||||||||||||||||||||||||||||||||
| */ | ||||||||||||||||||||||||||||||||||||||||||
| template <typename T> | ||||||||||||||||||||||||||||||||||||||||||
| inline Result<std::optional<T>> tryToOptional() const { | ||||||||||||||||||||||||||||||||||||||||||
| if (this->isNone()) { | ||||||||||||||||||||||||||||||||||||||||||
| return std::optional<T>(executorch::aten::nullopt); | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
| auto r = this->tryTo<T>(); | ||||||||||||||||||||||||||||||||||||||||||
| if (!r.ok()) { | ||||||||||||||||||||||||||||||||||||||||||
| return r.error(); | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
| return std::optional<T>(std::move(r.get())); | ||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+625
to
+639
|
||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| private: | ||||||||||||||||||||||||||||||||||||||||||
| // Pre cond: the payload value has had its destructor called | ||||||||||||||||||||||||||||||||||||||||||
| void clearToNone() noexcept { | ||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -524,7 +689,7 @@ | |||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| #define EVALUE_DEFINE_TO(T, method_name) \ | ||||||||||||||||||||||||||||||||||||||||||
| template <> \ | ||||||||||||||||||||||||||||||||||||||||||
| inline T EValue::to<T>()&& { \ | ||||||||||||||||||||||||||||||||||||||||||
| inline T EValue::to<T>() && { \ | ||||||||||||||||||||||||||||||||||||||||||
| return static_cast<T>(std::move(*this).method_name()); \ | ||||||||||||||||||||||||||||||||||||||||||
| } \ | ||||||||||||||||||||||||||||||||||||||||||
| template <> \ | ||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -538,7 +703,7 @@ | |||||||||||||||||||||||||||||||||||||||||
| template <> \ | ||||||||||||||||||||||||||||||||||||||||||
| inline ::executorch::runtime::internal::evalue_to_ref_overload_return< \ | ||||||||||||||||||||||||||||||||||||||||||
| T>::type \ | ||||||||||||||||||||||||||||||||||||||||||
| EValue::to<T>()& { \ | ||||||||||||||||||||||||||||||||||||||||||
| EValue::to<T>() & { \ | ||||||||||||||||||||||||||||||||||||||||||
| typedef ::executorch::runtime::internal::evalue_to_ref_overload_return< \ | ||||||||||||||||||||||||||||||||||||||||||
| T>::type return_type; \ | ||||||||||||||||||||||||||||||||||||||||||
| return static_cast<return_type>(this->method_name()); \ | ||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -591,6 +756,33 @@ | |||||||||||||||||||||||||||||||||||||||||
| toListOptionalTensor) | ||||||||||||||||||||||||||||||||||||||||||
| #undef EVALUE_DEFINE_TO | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| #define EVALUE_DEFINE_TRY_TO(T, method_name) \ | ||||||||||||||||||||||||||||||||||||||||||
| template <> \ | ||||||||||||||||||||||||||||||||||||||||||
| inline Result<T> EValue::tryTo<T>() const { \ | ||||||||||||||||||||||||||||||||||||||||||
| return this->method_name(); \ | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| EVALUE_DEFINE_TRY_TO(executorch::aten::Scalar, tryToScalar) | ||||||||||||||||||||||||||||||||||||||||||
| EVALUE_DEFINE_TRY_TO(int64_t, tryToInt) | ||||||||||||||||||||||||||||||||||||||||||
| EVALUE_DEFINE_TRY_TO(bool, tryToBool) | ||||||||||||||||||||||||||||||||||||||||||
| EVALUE_DEFINE_TRY_TO(double, tryToDouble) | ||||||||||||||||||||||||||||||||||||||||||
| EVALUE_DEFINE_TRY_TO(std::string_view, tryToString) | ||||||||||||||||||||||||||||||||||||||||||
| EVALUE_DEFINE_TRY_TO(executorch::aten::ScalarType, tryToScalarType) | ||||||||||||||||||||||||||||||||||||||||||
| EVALUE_DEFINE_TRY_TO(executorch::aten::MemoryFormat, tryToMemoryFormat) | ||||||||||||||||||||||||||||||||||||||||||
| EVALUE_DEFINE_TRY_TO(executorch::aten::Layout, tryToLayout) | ||||||||||||||||||||||||||||||||||||||||||
| EVALUE_DEFINE_TRY_TO(executorch::aten::Device, tryToDevice) | ||||||||||||||||||||||||||||||||||||||||||
| EVALUE_DEFINE_TRY_TO(executorch::aten::Tensor, tryToTensor) | ||||||||||||||||||||||||||||||||||||||||||
| EVALUE_DEFINE_TRY_TO(executorch::aten::ArrayRef<int64_t>, tryToIntList) | ||||||||||||||||||||||||||||||||||||||||||
| EVALUE_DEFINE_TRY_TO(executorch::aten::ArrayRef<double>, tryToDoubleList) | ||||||||||||||||||||||||||||||||||||||||||
| EVALUE_DEFINE_TRY_TO(executorch::aten::ArrayRef<bool>, tryToBoolList) | ||||||||||||||||||||||||||||||||||||||||||
| EVALUE_DEFINE_TRY_TO( | ||||||||||||||||||||||||||||||||||||||||||
| executorch::aten::ArrayRef<executorch::aten::Tensor>, | ||||||||||||||||||||||||||||||||||||||||||
| tryToTensorList) | ||||||||||||||||||||||||||||||||||||||||||
| EVALUE_DEFINE_TRY_TO( | ||||||||||||||||||||||||||||||||||||||||||
| executorch::aten::ArrayRef<std::optional<executorch::aten::Tensor>>, | ||||||||||||||||||||||||||||||||||||||||||
| tryToListOptionalTensor) | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
| tryToListOptionalTensor) | |
| tryToListOptionalTensor) | |
| #define EVALUE_DEFINE_TRY_TO_OPTIONAL(cpp_type) \ | |
| template <> \ | |
| inline Result<std::optional<cpp_type>> EValue::tryTo< \ | |
| std::optional<cpp_type>>() const { \ | |
| return tryToOptional<cpp_type>(); \ | |
| } | |
| EVALUE_DEFINE_TRY_TO_OPTIONAL(executorch::aten::Tensor) | |
| EVALUE_DEFINE_TRY_TO_OPTIONAL(executorch::aten::ArrayRef<int64_t>) | |
| EVALUE_DEFINE_TRY_TO_OPTIONAL(executorch::aten::ArrayRef<double>) | |
| EVALUE_DEFINE_TRY_TO_OPTIONAL(executorch::aten::ArrayRef<bool>) | |
| EVALUE_DEFINE_TRY_TO_OPTIONAL( | |
| executorch::aten::ArrayRef<executorch::aten::Tensor>) | |
| EVALUE_DEFINE_TRY_TO_OPTIONAL( | |
| executorch::aten::ArrayRef<std::optional<executorch::aten::Tensor>>) | |
| #undef EVALUE_DEFINE_TRY_TO_OPTIONAL |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
tryToTensor() returns a Tensor by value from a const EValue, which will typically incur a Tensor copy/refcount bump even when the caller only needs to inspect the tensor. If the goal is a safe, non-aborting accessor for performance-sensitive paths, consider returning a pointer type (e.g., Result<const Tensor*>) or adding an rvalue overload (tryToTensor() &&) that can move out, to avoid extra copies where possible.