Skip to content

Commit ffbc5e8

Browse files
tianleiwu and Copilot authored
Handle int overflow in rnn (#28003)
### Description Fixes two overflow/underflow bugs in the CPU RNN kernel (`rnn.cc`): - **`SafeInt` for GEMM M-dimension**: `seq_length * batch_size` was computed as a raw `int64_t` multiply before `narrow<int>()`, meaning an overflow would be UB before the check could fire. Replaced with `SafeInt<int64_t>(seq_length) * batch_size` for a checked multiply. - **`seq_length == 0` guard in `Assign_Y_h`**: For the forward direction, `last_time_step = seq_length - 1` underflows to `-1` when `seq_length == 0`, producing a negative `y_offset` and out-of-bounds read. Added an early-exit that zero-fills Y_h for the direction and returns. Also handles `sequence_lens[batch] == 0` (same underflow path), zeroing the affected batch slot and skipping via `continue`. ### Motivation and Context Silent UB from integer overflow/underflow in shape-derived index arithmetic can corrupt memory or produce incorrect results without any diagnostic signal. These cases are legal per the ONNX spec (empty sequences, per-batch zero-length sequences) and must be handled explicitly. --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
1 parent 50129c9 commit ffbc5e8

4 files changed

Lines changed: 224 additions & 20 deletions

File tree

onnxruntime/core/common/safeint.h

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,3 +36,48 @@ class SafeIntExceptionHandler<onnxruntime::OnnxRuntimeException> {
3636
#if defined(__GNUC__)
3737
#pragma GCC diagnostic pop
3838
#endif
39+
40+
#include <type_traits>
41+
42+
namespace onnxruntime {
43+
44+
template <typename T>
45+
using remove_cvref_t = std::remove_cv_t<std::remove_reference_t<T>>;
46+
47+
template <typename T>
48+
inline constexpr bool is_supported_integer_v =
49+
std::is_integral_v<remove_cvref_t<T>> && !std::is_same_v<remove_cvref_t<T>, bool>;
50+
51+
//------------------------------------------------------------------------------
52+
// Safe multiplication of two or more integer values into an explicit result type R.
53+
// Throws OnnxRuntimeException on overflow.
54+
//------------------------------------------------------------------------------
55+
template <typename R, typename T, typename U, typename... Rest>
56+
[[nodiscard]] R SafeMul(T a, U b, Rest... rest) {
57+
static_assert(is_supported_integer_v<R>,
58+
"SafeMul requires an integral result type (excluding bool)");
59+
static_assert(is_supported_integer_v<T> && is_supported_integer_v<U>,
60+
"SafeMul requires integral operand types (excluding bool)");
61+
static_assert((is_supported_integer_v<Rest> && ...),
62+
"SafeMul requires integral operand types (excluding bool)");
63+
64+
// SafeMultiply(T, U, T&) requires the first argument and result to share
65+
// the same type. Cast the first operand to R so the result is directly in R.
66+
R cast_a{};
67+
if (!SafeCast(a, cast_a)) {
68+
SafeIntDefaultExceptionHandler::SafeIntOnOverflow();
69+
}
70+
71+
R result{};
72+
if (!SafeMultiply(cast_a, b, result)) {
73+
SafeIntDefaultExceptionHandler::SafeIntOnOverflow();
74+
}
75+
76+
if constexpr (sizeof...(rest) > 0) {
77+
return SafeMul<R>(result, rest...);
78+
} else {
79+
return result;
80+
}
81+
}
82+
83+
} // namespace onnxruntime

onnxruntime/core/providers/cpu/rnn/rnn.cc

Lines changed: 38 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
#include "core/providers/cpu/rnn/rnn.h"
55

6+
#include "core/common/narrow.h"
67
#include "core/common/safeint.h"
78
#include "core/framework/op_kernel_context_internal.h"
89
#include "core/providers/cpu/rnn/rnn_activation_functors.h"
@@ -84,15 +85,32 @@ void ApplyActivationToBatches(const Tensor* sequence_lens, const T* h_prev, T* Y
8485
template <typename T>
8586
void Assign_Y_h(const T* Y_buffer_data, Tensor* Y_h, const Tensor* sequence_lens,
8687
int64_t num_directions, int direction, bool isReverse, int64_t batch_size, int64_t seq_length, int64_t hidden_size) {
88+
if (seq_length == 0) {
89+
// No sequence data was processed; zero out Y_h for this direction.
90+
const size_t y_h_direction_size = SafeMul<size_t>(batch_size, hidden_size);
91+
const size_t Y_h_direction_offset = SafeMul<size_t>(direction, y_h_direction_size);
92+
math::Set<T, CPUMathUtil>(y_h_direction_size, T{0},
93+
Y_h->MutableData<T>() + Y_h_direction_offset, &CPUMathUtil::Instance());
94+
return;
95+
}
96+
8797
for (int batch = 0; batch < batch_size; batch++) {
8898
int64_t last_time_step = isReverse ? 0 : seq_length - 1;
89-
if (nullptr != sequence_lens && !isReverse)
99+
if (nullptr != sequence_lens && !isReverse) {
90100
last_time_step = sequence_lens->Data<int>()[batch] - 1;
101+
if (last_time_step < 0) {
102+
// sequence_lens[batch] == 0: no data was processed for this batch; zero out Y_h.
103+
int64_t Y_h_offset = direction * batch_size * hidden_size + batch * hidden_size;
104+
math::Set<T, CPUMathUtil>(narrow<size_t>(hidden_size), T{0},
105+
Y_h->MutableData<T>() + Y_h_offset, &CPUMathUtil::Instance());
106+
continue;
107+
}
108+
}
91109
int64_t y_offset = last_time_step * num_directions * batch_size * hidden_size +
92110
direction * batch_size * hidden_size +
93111
batch * hidden_size;
94112
int64_t Y_h_offset = direction * batch_size * hidden_size + batch * hidden_size;
95-
math::CopyVector<T, CPUMathUtil>(static_cast<int>(hidden_size), Y_buffer_data + y_offset,
113+
math::CopyVector<T, CPUMathUtil>(narrow<int>(hidden_size), Y_buffer_data + y_offset,
96114
Y_h->MutableData<T>() + Y_h_offset,
97115
&CPUMathUtil::Instance());
98116
}
@@ -109,7 +127,7 @@ void ClearMissingFrames(T* Y_buffer_data, const Tensor* sequence_lens,
109127
seq * num_directions * batch_size * hidden_size +
110128
direction * batch_size * hidden_size +
111129
batch * hidden_size;
112-
math::Set<T, CPUMathUtil>(onnxruntime::narrow<size_t>(hidden_size), 0, Y_buffer_data + offset, &CPUMathUtil::Instance());
130+
math::Set<T, CPUMathUtil>(narrow<size_t>(hidden_size), 0, Y_buffer_data + offset, &CPUMathUtil::Instance());
113131
}
114132
}
115133
}
@@ -155,7 +173,7 @@ Status RNN<float>::Compute(OpKernelContext* ctx) const {
155173
ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&alloc));
156174

157175
// X * W^t, each direction has shape of [seq_length, batch_size, hidden_size]
158-
auto x_matmul_data = alloc->Alloc(SafeInt<size_t>(sizeof(float)) * seq_length * batch_size * hidden_size_);
176+
auto x_matmul_data = alloc->Alloc(SafeMul<size_t>(sizeof(float), seq_length, batch_size, hidden_size_));
159177
BufferUniquePtr x_matmul_buffer(x_matmul_data, BufferDeleter(alloc));
160178
auto* x_matmul_w_buffer_data = static_cast<float*>(x_matmul_buffer.get());
161179

@@ -165,7 +183,7 @@ Status RNN<float>::Compute(OpKernelContext* ctx) const {
165183
if (Y != nullptr)
166184
Y_buffer_data = Y->MutableData<float>();
167185
else {
168-
Y_data = alloc->Alloc(SafeInt<size_t>(sizeof(float)) * seq_length * num_directions * batch_size * hidden_size_);
186+
Y_data = alloc->Alloc(SafeMul<size_t>(sizeof(float), seq_length, num_directions, batch_size, hidden_size_));
169187
Y_matmul_buffer = BufferUniquePtr(Y_data, BufferDeleter(alloc));
170188
Y_buffer_data = static_cast<float*>(Y_matmul_buffer.get());
171189
}
@@ -177,20 +195,20 @@ Status RNN<float>::Compute(OpKernelContext* ctx) const {
177195
bool isReverse = direction_ == "reverse" || direction == 1;
178196

179197
if (B != nullptr) {
180-
EigenMatrixMapRowMajor<float>(x_matmul_w_buffer_data, seq_length * SafeInt<size_t>(batch_size), onnxruntime::narrow<size_t>(hidden_size_)).rowwise() =
181-
ConstEigenVectorMap<float>(B->Data<float>() + direction * 2 * hidden_size_, onnxruntime::narrow<size_t>(hidden_size_)).transpose() +
182-
ConstEigenVectorMap<float>(B->Data<float>() + direction * 2 * hidden_size_ + hidden_size_, onnxruntime::narrow<size_t>(hidden_size_)).transpose();
198+
EigenMatrixMapRowMajor<float>(x_matmul_w_buffer_data, SafeMul<size_t>(seq_length, batch_size), narrow<size_t>(hidden_size_)).rowwise() =
199+
ConstEigenVectorMap<float>(B->Data<float>() + direction * 2 * hidden_size_, narrow<size_t>(hidden_size_)).transpose() +
200+
ConstEigenVectorMap<float>(B->Data<float>() + direction * 2 * hidden_size_ + hidden_size_, narrow<size_t>(hidden_size_)).transpose();
183201
} else {
184-
math::Set<float, CPUMathUtil>(seq_length * batch_size * SafeInt<size_t>(hidden_size_), 0, x_matmul_w_buffer_data, &CPUMathUtil::Instance());
202+
math::Set<float, CPUMathUtil>(SafeMul<size_t>(seq_length, batch_size, hidden_size_), 0, x_matmul_w_buffer_data, &CPUMathUtil::Instance());
185203
}
186204

187205
// X * W[direction]^t + B
188206
math::Gemm<float>(
189207
CblasNoTrans,
190208
CblasTrans,
191-
static_cast<int>(seq_length * batch_size),
192-
static_cast<int>(hidden_size_),
193-
static_cast<int>(input_size),
209+
SafeMul<int>(seq_length, batch_size),
210+
narrow<int>(hidden_size_),
211+
narrow<int>(input_size),
194212
1,
195213
X.Data<float>(),
196214
W.Data<float>() + direction * hidden_size_ * input_size,
@@ -202,7 +220,7 @@ Status RNN<float>::Compute(OpKernelContext* ctx) const {
202220
int64_t time_step = isReverse ? (seq_length - t - 1) : t;
203221
int64_t Y_frame_offset = (time_step * num_directions + direction) * Y_frame_size;
204222
float* Y_buffer_data_current_frame = Y_buffer_data + Y_frame_offset;
205-
auto y_frame_mat = EigenMatrixMapRowMajor<float>(Y_buffer_data_current_frame, onnxruntime::narrow<size_t>(batch_size), onnxruntime::narrow<size_t>(hidden_size_));
223+
auto y_frame_mat = EigenMatrixMapRowMajor<float>(Y_buffer_data_current_frame, narrow<size_t>(batch_size), narrow<size_t>(hidden_size_));
206224

207225
const float* h_prev = nullptr;
208226
if (t == 0) {
@@ -224,21 +242,21 @@ Status RNN<float>::Compute(OpKernelContext* ctx) const {
224242
math::Gemm<float>(
225243
CblasNoTrans,
226244
CblasTrans,
227-
static_cast<int>(batch_size),
228-
static_cast<int>(hidden_size_),
229-
static_cast<int>(hidden_size_),
245+
narrow<int>(batch_size),
246+
narrow<int>(hidden_size_),
247+
narrow<int>(hidden_size_),
230248
1,
231249
h_prev,
232250
R.Data<float>() + direction * hidden_size_ * hidden_size_,
233251
0,
234252
Y_buffer_data_current_frame,
235253
tp, &mlas_backend_kernel_selector_config_);
236254
} else {
237-
math::Set<float, CPUMathUtil>(batch_size * SafeInt<size_t>(hidden_size_), 0, Y_buffer_data_current_frame, &CPUMathUtil::Instance());
255+
math::Set<float, CPUMathUtil>(SafeMul<size_t>(batch_size, hidden_size_), 0, Y_buffer_data_current_frame, &CPUMathUtil::Instance());
238256
}
239257

240258
// X[time_step] * W^t + H_t_1 * R^t
241-
y_frame_mat += EigenMatrixMapRowMajor<float>(&x_matmul_w_buffer_data[time_step * Y_frame_size], onnxruntime::narrow<size_t>(batch_size), onnxruntime::narrow<size_t>(hidden_size_));
259+
y_frame_mat += EigenMatrixMapRowMajor<float>(&x_matmul_w_buffer_data[time_step * Y_frame_size], narrow<size_t>(batch_size), narrow<size_t>(hidden_size_));
242260

243261
// apply activation
244262
ApplyActivationToBatches<float>(sequence_lens, h_prev, Y_buffer_data_current_frame,
@@ -258,10 +276,10 @@ Status RNN<float>::Compute(OpKernelContext* ctx) const {
258276
}
259277

260278
if (Y != nullptr)
261-
DumpMatrix("Y", Y_buffer_data, (int)(seq_length * num_directions * batch_size), (int)hidden_size_);
279+
DumpMatrix("Y", Y_buffer_data, SafeMul<int>(seq_length, num_directions, batch_size), narrow<int>(hidden_size_));
262280

263281
if (Y_h != nullptr)
264-
DumpMatrix("Y_h", Y_h->Data<float>(), (int)(num_directions * batch_size), (int)hidden_size_);
282+
DumpMatrix("Y_h", Y_h->Data<float>(), SafeMul<int>(num_directions, batch_size), narrow<int>(hidden_size_));
265283

266284
return Status::OK();
267285
}
onnxruntime/test/common/safeint_test.cc

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#include "core/common/safeint.h"
5+
6+
#include <cstddef>
7+
#include <cstdint>
8+
#include <limits>
9+
10+
#include "gtest/gtest.h"
11+
12+
namespace onnxruntime::test {
13+
14+
static_assert(is_supported_integer_v<int>);
15+
static_assert(is_supported_integer_v<uint8_t>);
16+
static_assert(!is_supported_integer_v<bool>);
17+
18+
TEST(SafeIntTest, SafeMulMultipliesOperands) {
19+
EXPECT_EQ(SafeMul<size_t>(size_t{2}, 3U), size_t{6});
20+
EXPECT_EQ(SafeMul<int>(-2, 3, 4), -24);
21+
}
22+
23+
TEST(SafeIntTest, SafeMulHandlesSameVariableOperands) {
24+
const int value = 7;
25+
EXPECT_EQ(SafeMul<int>(value, value), 49);
26+
}
27+
28+
#ifndef ORT_NO_EXCEPTIONS
29+
TEST(SafeIntTest, SafeMulThrowsOnInitialCastOverflow) {
30+
EXPECT_THROW((void)SafeMul<uint32_t>(-1, 2), OnnxRuntimeException);
31+
}
32+
33+
TEST(SafeIntTest, SafeMulThrowsOnMultiplyOverflow) {
34+
EXPECT_THROW((void)SafeMul<int>(std::numeric_limits<int>::max(), 2), OnnxRuntimeException);
35+
}
36+
#endif
37+
38+
} // namespace onnxruntime::test

onnxruntime/test/providers/cpu/rnn/rnn_op_test.cc

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
// Copyright (c) Microsoft Corporation. All rights reserved.
22
// Licensed under the MIT License.
33

4+
#include <cmath>
5+
46
#include "core/providers/cpu/rnn/rnn.h"
57
#include "gtest/gtest.h"
68
#include "test/providers/provider_test_utils.h"
@@ -883,5 +885,106 @@ TEST(RNNTest, RNN_with_invalid_activation_load_failure) {
883885
{kCudaExecutionProvider, kTensorrtExecutionProvider});
884886
}
885887

888+
// Test that seq_length == 0 produces zero-filled Y and Y_h without crashing.
889+
TEST(RNNTest, RNN_seq_length_zero) {
890+
auto cpu = DefaultCpuExecutionProvider();
891+
if (!cpu) GTEST_SKIP() << "CPU EP not available in this build.";
892+
893+
OpTester test("RNN");
894+
int64_t num_directions = 1, input_size = 2, hidden_size = 3, batch_size = 2, seq_length = 0;
895+
896+
test.AddAttribute("activations", vector<string>(num_directions, "Tanh"));
897+
test.AddAttribute("direction", "forward");
898+
test.AddAttribute("hidden_size", hidden_size);
899+
900+
std::vector<int64_t> X_dims = {seq_length, batch_size, input_size};
901+
std::vector<float> X_data{};
902+
test.AddInput<float>("X", X_dims, X_data);
903+
904+
std::vector<int64_t> W_dims = {num_directions, hidden_size, input_size};
905+
std::vector<float> W_data({-0.1f, 0.2f, 1.f, -2.f, -1.f, 3.f});
906+
test.AddInput<float>("W", W_dims, W_data);
907+
908+
std::vector<int64_t> R_dims = {num_directions, hidden_size, hidden_size};
909+
std::vector<float> R_data(hidden_size * hidden_size, 0.f);
910+
test.AddInput<float>("R", R_dims, R_data);
911+
912+
// Y: shape [0, 1, 2, 3] -> empty
913+
std::vector<int64_t> Y_dims = {seq_length, num_directions, batch_size, hidden_size};
914+
std::vector<float> Y_data{};
915+
test.AddOutput<float>("Y", Y_dims, Y_data);
916+
917+
// Y_h: shape [1, 2, 3] -> all zeros
918+
std::vector<int64_t> Y_h_dims{num_directions, batch_size, hidden_size};
919+
std::vector<float> Y_h_data(num_directions * batch_size * hidden_size, 0.f);
920+
test.AddOutput<float>("Y_h", Y_h_dims, Y_h_data);
921+
test.ConfigEp(std::move(cpu)).RunWithConfig();
922+
}
923+
924+
// Test that per-batch sequence_lens containing 0 produces zero-filled Y_h for those batches.
925+
TEST(RNNTest, RNN_forward_sequence_lens_with_zero) {
926+
auto cpu = DefaultCpuExecutionProvider();
927+
if (!cpu) GTEST_SKIP() << "CPU EP not available in this build.";
928+
929+
OpTester test("RNN");
930+
int64_t num_directions = 1, input_size = 2, hidden_size = 3, batch_size = 2, seq_length = 2;
931+
932+
test.AddAttribute("activations", vector<string>(num_directions, "Tanh"));
933+
test.AddAttribute("direction", "forward");
934+
test.AddAttribute("hidden_size", hidden_size);
935+
936+
// X shape: [seq_length=2, batch_size=2, input_size=2]
937+
std::vector<int64_t> X_dims = {seq_length, batch_size, input_size};
938+
std::vector<float> X_data({0.1f, 0.2f,
939+
0.3f, 0.4f,
940+
0.5f, 0.6f,
941+
0.7f, 0.8f});
942+
test.AddInput<float>("X", X_dims, X_data);
943+
944+
std::vector<int64_t> W_dims = {num_directions, hidden_size, input_size};
945+
std::vector<float> W_data({-0.1f, 0.2f, 1.f, -2.f, -1.f, 3.f});
946+
test.AddInput<float>("W", W_dims, W_data);
947+
948+
std::vector<int64_t> R_dims = {num_directions, hidden_size, hidden_size};
949+
std::vector<float> R_data(hidden_size * hidden_size, 0.f);
950+
test.AddInput<float>("R", R_dims, R_data);
951+
952+
std::vector<int64_t> B_dims = {num_directions, 2 * hidden_size};
953+
std::vector<float> B_data(2 * hidden_size, 0.f);
954+
test.AddInput<float>("B", B_dims, B_data);
955+
956+
// batch 0 has sequence_lens=2, batch 1 has sequence_lens=0
957+
std::vector<int64_t> sequence_lens_dims{batch_size};
958+
std::vector<int> sequence_lens_data{2, 0};
959+
test.AddInput<int>("sequence_lens", sequence_lens_dims, sequence_lens_data);
960+
961+
std::vector<int64_t> initial_h_dims = {num_directions, batch_size, hidden_size};
962+
std::vector<float> initial_h_data(num_directions * batch_size * hidden_size, 0.f);
963+
test.AddInput<float>("initial_h", initial_h_dims, initial_h_data);
964+
965+
// Y output is optional; skip it to keep test simple.
966+
test.AddOptionalOutputEdge<float>();
967+
968+
// Y_h: shape [1, 2, 3]
969+
// batch 0 gets the result of forward pass at last time step (seq_length-1=1).
970+
// batch 1 has sequence_lens=0 so Y_h should be zero.
971+
//
972+
// For batch 0:
973+
// time_step 0: X=[0.1, 0.2], Y = tanh(X * W^T) = tanh([-0.1*0.1+0.2*0.2, 1*0.1-2*0.2, -1*0.1+3*0.2])
974+
// = tanh([0.03, -0.3, 0.5])
975+
// time_step 1: X=[0.5, 0.6], Y = tanh(X * W^T + H_prev * R^T)
976+
// R is zero, so Y = tanh([-0.1*0.5+0.2*0.6, 1*0.5-2*0.6, -1*0.5+3*0.6])
977+
// = tanh([0.07, -0.7, 1.3])
978+
float y_h_batch0_f0 = std::tanh(0.07f);
979+
float y_h_batch0_f1 = std::tanh(-0.7f);
980+
float y_h_batch0_f2 = std::tanh(1.3f);
981+
982+
std::vector<int64_t> Y_h_dims{num_directions, batch_size, hidden_size};
983+
std::vector<float> Y_h_data{y_h_batch0_f0, y_h_batch0_f1, y_h_batch0_f2,
984+
0.f, 0.f, 0.f};
985+
test.AddOutput<float>("Y_h", Y_h_dims, Y_h_data);
986+
test.ConfigEp(std::move(cpu)).RunWithConfig();
987+
}
988+
886989
} // namespace test
887990
} // namespace onnxruntime

0 commit comments

Comments
 (0)