Skip to content

Commit e093efd

Browse files
committed
add tests
1 parent 372f0bf commit e093efd

2 files changed

Lines changed: 45 additions & 9 deletions

File tree

onnxruntime/core/common/safeint.h

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -63,23 +63,21 @@ template <typename R, typename T, typename U, typename... Rest>
6363

6464
// SafeMultiply(T, U, T&) requires the first argument and result to share
6565
// the same type. Cast the first operand to R so the result is directly in R.
66-
R result{};
67-
if constexpr (std::is_same_v<R, T>) {
68-
result = a;
69-
} else {
70-
if (!SafeCast(a, result)) {
71-
SafeIntDefaultExceptionHandler::SafeIntOnOverflow();
72-
}
66+
R cast_a{};
67+
if (!SafeCast(a, cast_a)) {
68+
SafeIntDefaultExceptionHandler::SafeIntOnOverflow();
7369
}
7470

75-
if (!SafeMultiply(result, b, result)) {
71+
R result{};
72+
if (!SafeMultiply(cast_a, b, result)) {
7673
SafeIntDefaultExceptionHandler::SafeIntOnOverflow();
7774
}
7875

7976
if constexpr (sizeof...(rest) > 0) {
8077
return SafeMul<R>(result, rest...);
78+
} else {
79+
return result;
8180
}
82-
return result;
8381
}
8482

8583
} // namespace onnxruntime
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#include "core/common/safeint.h"
5+
6+
#include <cstddef>
7+
#include <cstdint>
8+
#include <limits>
9+
10+
#include "gtest/gtest.h"
11+
12+
namespace onnxruntime::test {
13+
14+
static_assert(is_supported_integer_v<int>);
15+
static_assert(is_supported_integer_v<uint8_t>);
16+
static_assert(!is_supported_integer_v<bool>);
17+
18+
TEST(SafeIntTest, SafeMulMultipliesOperands) {
19+
EXPECT_EQ(SafeMul<size_t>(size_t{2}, 3U), size_t{6});
20+
EXPECT_EQ(SafeMul<int>(-2, 3, 4), -24);
21+
}
22+
23+
TEST(SafeIntTest, SafeMulHandlesSameVariableOperands) {
24+
const int value = 7;
25+
EXPECT_EQ(SafeMul<int>(value, value), 49);
26+
}
27+
28+
#ifndef ORT_NO_EXCEPTIONS
29+
TEST(SafeIntTest, SafeMulThrowsOnInitialCastOverflow) {
30+
EXPECT_THROW((SafeMul<uint32_t>(-1, 2)), OnnxRuntimeException);
31+
}
32+
33+
TEST(SafeIntTest, SafeMulThrowsOnMultiplyOverflow) {
34+
EXPECT_THROW((SafeMul<int>(std::numeric_limits<int>::max(), 2)), OnnxRuntimeException);
35+
}
36+
#endif
37+
38+
} // namespace onnxruntime::test

0 commit comments

Comments
 (0)