Skip to content

Commit 90e6e4c

Browse files
rascaniclaude
andauthored
Validate dim_order is a permutation in dim_order_to_stride (#17314)
### Summary The validate_dim_order function only checked that values were in bounds, allowing invalid inputs like {0, 0, 0} to pass. This caused uninitialized memory access in dim_order_to_stride_nocheck. Fix by using a bitmask to detect duplicates. Also adds test fixture with runtime_init() for error logging and removes duplicate include. ### Test plan ``` ./test/run_oss_cpp_tests.sh ``` --------- Co-authored-by: Claude <noreply@anthropic.com>
1 parent b4fc4d3 commit 90e6e4c

2 files changed

Lines changed: 80 additions & 9 deletions

File tree

runtime/core/exec_aten/util/dim_order_util.h

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,13 @@
88

99
#pragma once
1010

11-
#include <c10/util/irange.h>
1211
#include <cstdint>
1312
#include <cstdio>
1413
#include <cstring>
1514

1615
#include <c10/util/irange.h>
1716
#include <executorch/runtime/core/error.h>
17+
#include <executorch/runtime/core/exec_aten/util/tensor_dimension_limit.h>
1818
#include <executorch/runtime/platform/assert.h>
1919
#include <executorch/runtime/platform/compiler.h>
2020

@@ -24,10 +24,22 @@ namespace runtime {
2424
namespace {
2525
template <typename DimOrderType>
2626
bool validate_dim_order(const DimOrderType* dim_order, const size_t dims) {
27+
static_assert(
28+
kTensorDimensionLimit <= 16,
29+
"Bitmask-based validation requires kTensorDimensionLimit <= 16");
30+
if (dims > kTensorDimensionLimit) {
31+
return false;
32+
}
33+
uint16_t seen = 0;
2734
for (const auto i : c10::irange(dims)) {
2835
if (dim_order[i] >= static_cast<DimOrderType>(dims)) {
2936
return false;
3037
}
38+
const uint16_t mask = 1u << dim_order[i];
39+
if (seen & mask) {
40+
return false;
41+
}
42+
seen |= mask;
3143
}
3244
return true;
3345
}
@@ -150,7 +162,7 @@ ET_NODISCARD inline Error dim_order_to_stride(
150162
ET_CHECK_OR_RETURN_ERROR(
151163
validate_dim_order(dim_order, dims),
152164
InvalidArgument,
153-
"Invalid dim order. One of the value is larger than the number of dims %zu",
165+
"Invalid dim order: values must be a permutation of [0, %zu)",
154166
dims);
155167

156168
dim_order_to_stride_nocheck(sizes, dim_order, dims, strides);

runtime/core/exec_aten/util/test/dim_order_util_test.cpp

Lines changed: 66 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include <numeric>
1313

1414
#include <executorch/runtime/core/exec_aten/exec_aten.h>
15+
#include <executorch/runtime/platform/runtime.h>
1516

1617
#include <gtest/gtest.h>
1718

@@ -21,6 +22,15 @@ using executorch::runtime::is_channels_last_dim_order;
2122
using executorch::runtime::is_contiguous_dim_order;
2223
using executorch::runtime::stride_to_dim_order;
2324

25+
class DimOrderUtilTest : public ::testing::Test {
26+
protected:
27+
void SetUp() override {
28+
// As some of these tests cause ET_LOG to be called, the PAL must be
29+
// initialized first by calling runtime_init();
30+
executorch::runtime::runtime_init();
31+
}
32+
};
33+
2434
namespace {
2535
void check_strides_eq(
2636
executorch::aten::ArrayRef<executorch::aten::StridesType> strides_a,
@@ -39,7 +49,7 @@ void check_dim_order_eq(
3949
}
4050
} // namespace
4151

42-
TEST(DimOrderUtilTest, DimOrderToStride) {
52+
TEST_F(DimOrderUtilTest, DimOrderToStride) {
4353
executorch::aten::SizesType sizes_1[1] = {5};
4454
executorch::aten::SizesType dim_order_1[1] = {0};
4555
executorch::aten::SizesType strides_1[1] = {0};
@@ -204,7 +214,7 @@ TEST(DimOrderUtilTest, DimOrderToStride) {
204214
check_strides_eq({strides_3_zero, 3}, {expected_strides_3_zero, 3});
205215
}
206216

207-
TEST(DimOrderUtilTest, StrideToDimOrder) {
217+
TEST_F(DimOrderUtilTest, StrideToDimOrder) {
208218
executorch::aten::SizesType strides[3] = {5, 1, 15};
209219
executorch::aten::DimOrderType dim_order[3] = {0, 0, 0};
210220

@@ -216,7 +226,7 @@ TEST(DimOrderUtilTest, StrideToDimOrder) {
216226
check_dim_order_eq(dim_order, expected_dim_order);
217227
}
218228

219-
TEST(DimOrderUtilTest, StrideToDimOrderSameStrides) {
229+
TEST_F(DimOrderUtilTest, StrideToDimOrderSameStrides) {
220230
executorch::aten::SizesType strides[4] = {4, 3, 1, 1};
221231
executorch::aten::DimOrderType dim_order[4] = {0, 0, 0, 0};
222232

@@ -227,7 +237,7 @@ TEST(DimOrderUtilTest, StrideToDimOrderSameStrides) {
227237
check_dim_order_eq(dim_order, expected_dim_order);
228238
}
229239

230-
TEST(DimOrderUtilTest, IsDefaultDimOrderTest) {
240+
TEST_F(DimOrderUtilTest, IsDefaultDimOrderTest) {
231241
for (const auto i : c10::irange(1, 7)) {
232242
std::vector<executorch::aten::DimOrderType> dim_order(i);
233243
std::iota(dim_order.begin(), dim_order.end(), 0);
@@ -240,7 +250,7 @@ TEST(DimOrderUtilTest, IsDefaultDimOrderTest) {
240250
}
241251
}
242252

243-
TEST(DimOrderUtilTest, IsDefaultDimOrderFailCasesTest) {
253+
TEST_F(DimOrderUtilTest, IsDefaultDimOrderFailCasesTest) {
244254
// Dims is default order but have two elements swapped
245255
for (const auto i : c10::irange(3, 8)) {
246256
std::vector<executorch::aten::DimOrderType> dim_order(i);
@@ -261,7 +271,7 @@ TEST(DimOrderUtilTest, IsDefaultDimOrderFailCasesTest) {
261271
}
262272
}
263273

264-
TEST(DimOrderUtilTest, IsChannelsLastDimOrderTest) {
274+
TEST_F(DimOrderUtilTest, IsChannelsLastDimOrderTest) {
265275
executorch::aten::DimOrderType dim_order_4d[4] = {0, 2, 3, 1};
266276
executorch::aten::DimOrderType dim_order_5d[5] = {0, 2, 3, 4, 1};
267277

@@ -273,7 +283,7 @@ TEST(DimOrderUtilTest, IsChannelsLastDimOrderTest) {
273283
EXPECT_FALSE(is_contiguous_dim_order(dim_order_5d, 5));
274284
}
275285

276-
TEST(DimOrderUtilTest, IsChannelsLastDimOrderFailCasesTest) {
286+
TEST_F(DimOrderUtilTest, IsChannelsLastDimOrderFailCasesTest) {
277287
// Non 4D and 5D dim order returns false
278288
executorch::aten::DimOrderType dim_order_3d[4] = {1, 2, 0};
279289
executorch::aten::DimOrderType dim_order_6d[6] = {0, 2, 3, 4, 5, 1};
@@ -287,3 +297,52 @@ TEST(DimOrderUtilTest, IsChannelsLastDimOrderFailCasesTest) {
287297
EXPECT_FALSE(is_channels_last_dim_order(dim_order_4d, 4));
288298
EXPECT_FALSE(is_channels_last_dim_order(dim_order_5d, 5));
289299
}
300+
301+
TEST_F(DimOrderUtilTest, DimOrderWithAllDuplicatesReturnsError) {
302+
executorch::aten::SizesType sizes[3] = {2, 3, 4};
303+
executorch::aten::SizesType dim_order[3] = {0, 0, 0};
304+
executorch::aten::SizesType strides[3] = {0, 0, 0};
305+
306+
auto error = dim_order_to_stride(sizes, dim_order, 3, strides);
307+
EXPECT_EQ(error, Error::InvalidArgument);
308+
}
309+
310+
TEST_F(DimOrderUtilTest, DimOrderWithPartialDuplicateReturnsError) {
311+
executorch::aten::SizesType sizes[3] = {2, 3, 4};
312+
executorch::aten::SizesType dim_order[3] = {0, 1, 1};
313+
executorch::aten::SizesType strides[3] = {0, 0, 0};
314+
315+
auto error = dim_order_to_stride(sizes, dim_order, 3, strides);
316+
EXPECT_EQ(error, Error::InvalidArgument);
317+
}
318+
319+
TEST_F(DimOrderUtilTest, DimOrderWithMissingValueReturnsError) {
320+
executorch::aten::SizesType sizes[3] = {2, 3, 4};
321+
executorch::aten::SizesType dim_order[3] = {1, 2, 2};
322+
executorch::aten::SizesType strides[3] = {0, 0, 0};
323+
324+
auto error = dim_order_to_stride(sizes, dim_order, 3, strides);
325+
EXPECT_EQ(error, Error::InvalidArgument);
326+
}
327+
328+
TEST_F(DimOrderUtilTest, DimOrderWithOutOfBoundsValueReturnsError) {
329+
executorch::aten::SizesType sizes[3] = {2, 3, 4};
330+
executorch::aten::SizesType dim_order[3] = {0, 1, 5};
331+
executorch::aten::SizesType strides[3] = {0, 0, 0};
332+
333+
auto error = dim_order_to_stride(sizes, dim_order, 3, strides);
334+
EXPECT_EQ(error, Error::InvalidArgument);
335+
}
336+
337+
TEST_F(DimOrderUtilTest, TooManyDimsReturnsError) {
338+
constexpr size_t kTooManyDims =
339+
executorch::runtime::kTensorDimensionLimit + 1;
340+
std::vector<executorch::aten::SizesType> sizes(kTooManyDims, 1);
341+
std::vector<executorch::aten::SizesType> dim_order(kTooManyDims);
342+
std::iota(dim_order.begin(), dim_order.end(), 0);
343+
std::vector<executorch::aten::SizesType> strides(kTooManyDims, 0);
344+
345+
auto error = dim_order_to_stride(
346+
sizes.data(), dim_order.data(), kTooManyDims, strides.data());
347+
EXPECT_EQ(error, Error::InvalidArgument);
348+
}

0 commit comments

Comments
 (0)