|
8 | 8 |
|
9 | 9 | #include <executorch/runtime/executor/tensor_parser.h> |
10 | 10 |
|
| 11 | +#include <cstring> |
| 12 | + |
11 | 13 | #include <executorch/extension/data_loader/file_data_loader.h> |
12 | 14 | #include <executorch/runtime/core/exec_aten/exec_aten.h> |
13 | 15 | #include <executorch/runtime/core/tensor_layout.h> |
|
19 | 21 | using namespace ::testing; |
20 | 22 | using executorch::aten::ScalarType; |
21 | 23 | using executorch::aten::Tensor; |
| 24 | +using executorch::runtime::BoxedEvalueList; |
22 | 25 | using executorch::runtime::Error; |
23 | 26 | using executorch::runtime::EValue; |
24 | 27 | using executorch::runtime::FreeableBuffer; |
25 | 28 | using executorch::runtime::Program; |
26 | 29 | using executorch::runtime::Result; |
27 | 30 | using executorch::runtime::Span; |
28 | 31 | using executorch::runtime::TensorLayout; |
| 32 | +using executorch::runtime::deserialization::parseListOptionalType; |
29 | 33 | using executorch::runtime::deserialization::parseTensor; |
| 34 | +using executorch::runtime::deserialization::parseTensorList; |
30 | 35 | using executorch::runtime::deserialization::validateTensorLayout; |
31 | 36 | using executorch::runtime::testing::ManagedMemoryManager; |
32 | 37 | using torch::executor::util::FileDataLoader; |
@@ -223,3 +228,60 @@ TEST(ValidateTensorLayoutTest, DimOrderSizeMismatchIsRejected) { |
223 | 228 | EXPECT_EQ( |
224 | 229 | validateTensorLayout(s_tensor, layout.get()), Error::InvalidExternalData); |
225 | 230 | } |
| 231 | + |
| 232 | +// Helper to construct a flatbuffers::Vector<int32_t> from raw data. |
| 233 | +// FlatBuffer vectors are stored as [uint32_t length][T elements...]. |
| 234 | +namespace { |
| 235 | +struct FlatVectorInt32 { |
| 236 | + static const flatbuffers::Vector<int32_t>* create( |
| 237 | + std::vector<uint8_t>& buf, |
| 238 | + const std::vector<int32_t>& elements) { |
| 239 | + buf.resize(sizeof(uint32_t) + elements.size() * sizeof(int32_t)); |
| 240 | + uint32_t len = static_cast<uint32_t>(elements.size()); |
| 241 | + memcpy(buf.data(), &len, sizeof(len)); |
| 242 | + if (!elements.empty()) { |
| 243 | + memcpy( |
| 244 | + buf.data() + sizeof(uint32_t), |
| 245 | + elements.data(), |
| 246 | + elements.size() * sizeof(int32_t)); |
| 247 | + } |
| 248 | + return reinterpret_cast<const flatbuffers::Vector<int32_t>*>(buf.data()); |
| 249 | + } |
| 250 | +}; |
| 251 | +} // namespace |
| 252 | + |
| 253 | +// parseTensorList should return an error when the EValue at the given index |
| 254 | +// is not a Tensor, instead of aborting. |
| 255 | +TEST_F(TensorParserTest, ParseTensorListRejectsNonTensorEValue) { |
| 256 | + ManagedMemoryManager mmm(kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes); |
| 257 | + |
| 258 | + // Create an EValue array with a non-Tensor value at index 0. |
| 259 | + EValue values[2]; |
| 260 | + values[0] = EValue(static_cast<int64_t>(42)); // Int, not Tensor |
| 261 | + values[1] = EValue(static_cast<int64_t>(7)); |
| 262 | + |
| 263 | + // Create a vector with index 0 (pointing to the Int EValue). |
| 264 | + std::vector<uint8_t> vec_buf; |
| 265 | + auto* indices = FlatVectorInt32::create(vec_buf, {0}); |
| 266 | + |
| 267 | + auto result = parseTensorList(indices, values, 2, &mmm.get()); |
| 268 | + EXPECT_EQ(result.error(), Error::InvalidType); |
| 269 | +} |
| 270 | + |
| 271 | +// parseListOptionalType should return an error when the EValue at the given |
| 272 | +// index is neither None nor the expected type. |
| 273 | +TEST_F(TensorParserTest, ParseListOptionalTypeRejectsWrongType) { |
| 274 | + ManagedMemoryManager mmm(kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes); |
| 275 | + |
| 276 | + // Create an EValue array with a non-Tensor, non-None value at index 0. |
| 277 | + EValue values[2]; |
| 278 | + values[0] = EValue(static_cast<int64_t>(42)); // Int, not Tensor or None |
| 279 | + values[1] = EValue(static_cast<int64_t>(7)); |
| 280 | + |
| 281 | + // Create a vector with index 0 (pointing to the Int EValue). |
| 282 | + std::vector<uint8_t> vec_buf; |
| 283 | + auto* indices = FlatVectorInt32::create(vec_buf, {0}); |
| 284 | + |
| 285 | + auto result = parseListOptionalType<Tensor>(indices, values, 2, &mmm.get()); |
| 286 | + EXPECT_EQ(result.error(), Error::InvalidType); |
| 287 | +} |
0 commit comments