Skip to content

Commit ec5e8e4

Browse files
authored
Fix overflows in et (pytorch#19057)
### Summary Check various overflows
1 parent 476a7ef commit ec5e8e4

7 files changed

Lines changed: 34 additions & 10 deletions

File tree

backends/apple/metal/runtime/metal_backend.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9+
#include <c10/util/safe_numerics.h>
910
#include <dlfcn.h>
1011
#include <executorch/runtime/backend/interface.h>
1112
#include <executorch/runtime/core/error.h>
@@ -459,8 +460,10 @@ class ET_EXPERIMENTAL MetalBackend final
459460

460461
ET_LOG(Debug, "MetalBackend n_outputs %zd generated", n_outputs);
461462

463+
size_t n_io_sum = 0;
462464
ET_CHECK_OR_RETURN_ERROR(
463-
n_inputs + n_outputs == args.size(),
465+
!c10::add_overflows(n_inputs, n_outputs, &n_io_sum) &&
466+
n_io_sum == args.size(),
464467
InvalidArgument,
465468
"number of user input %zd and output %zd generated from AOT Inductor does not match ET runner's %zd. Exit.",
466469
n_inputs,

backends/cuda/runtime/cuda_backend.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9+
#include <c10/util/safe_numerics.h>
910
#include <cuda_runtime.h>
1011
#include <executorch/runtime/backend/interface.h>
1112
#include <executorch/runtime/backend/options.h>
@@ -550,8 +551,10 @@ class ET_EXPERIMENTAL CudaBackend final
550551

551552
setCurrentCUDAStream(handle->get_cuda_stream(), 0);
552553

554+
size_t n_io_sum = 0;
553555
ET_CHECK_OR_RETURN_ERROR(
554-
n_inputs + n_outputs == args.size(),
556+
!c10::add_overflows(n_inputs, n_outputs, &n_io_sum) &&
557+
n_io_sum == args.size(),
555558
InvalidArgument,
556559
"number of user input %zd and output %zd generated from AOT Inductor does not match ET runner's %zd. Exit.",
557560
n_inputs,

examples/qualcomm/oss_scripts/llama/runner/prompt_processor.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9+
#include <c10/util/safe_numerics.h>
910
#include <executorch/examples/qualcomm/oss_scripts/llama/runner/prompt_processor.h>
1011
#include <numeric>
1112
using executorch::aten::TensorImpl;
@@ -248,9 +249,12 @@ Result<uint64_t> PromptProcessor<T>::prefill(
248249
ET_CHECK_MSG(
249250
start_pos == 0, "Bert model doesn't support multi-turn conversation.");
250251
} else if (!enable_attention_sink) {
252+
int64_t end_pos = 0;
251253
ET_CHECK_MSG(
252-
(start_pos + num_prompt_tokens) <=
253-
(metadata_.context_len - metadata_.ar_len),
254+
!c10::add_overflows(
255+
start_pos, static_cast<int64_t>(num_prompt_tokens), &end_pos) &&
256+
end_pos <= static_cast<int64_t>(metadata_.context_len) -
257+
static_cast<int64_t>(metadata_.ar_len),
254258
"The sequence length exceeds the maximum limit that the prompt processor can handle.");
255259
}
256260

examples/qualcomm/oss_scripts/llama/runner/runner.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
// A llama 3.2 runner that includes preprocessing and post processing
1010
// logic. The module takes in a string as input and emits a string as output.
1111

12+
#include <c10/util/safe_numerics.h>
1213
#include <executorch/examples/models/llama/runner/runner.h>
1314
#include <executorch/examples/models/llama/tokenizer/llama_tiktoken.h>
1415
#include <executorch/examples/qualcomm/oss_scripts/llama/runner/client_mem.h>
@@ -433,8 +434,11 @@ Error Runner<T>::generate_from_prompt_or_file(
433434
}
434435
int num_prompt_tokens = prompt_tokens.size();
435436
ET_CHECK_MSG(num_prompt_tokens >= 1, "Expected at least 1 prompt token");
437+
int64_t end_pos = 0;
436438
ET_CHECK_MSG(
437-
cur_pos_ + num_prompt_tokens < seq_len,
439+
!c10::add_overflows(
440+
cur_pos_, static_cast<int64_t>(num_prompt_tokens), &end_pos) &&
441+
end_pos < static_cast<int64_t>(seq_len),
438442
"sequence length exceeded - please increase the seq_len value");
439443

440444
// Prompt Processor first

extension/flat_tensor/flat_tensor_data_map.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
#include <executorch/extension/flat_tensor/flat_tensor_data_map.h>
1010

11+
#include <c10/util/safe_numerics.h>
12+
1113
#include <executorch/extension/flat_tensor/serialize/flat_tensor_generated.h>
1214
#include <executorch/extension/flat_tensor/serialize/flat_tensor_header.h>
1315

@@ -73,9 +75,13 @@ Result<const flat_tensor_flatbuffer::NamedData*> get_named_data(
7375
key.data(),
7476
segments->size());
7577
// Validate the segment.
78+
uint64_t seg_end = 0;
7679
ET_CHECK_OR_RETURN_ERROR(
77-
(segments->Get(segment_index)->offset() +
78-
segments->Get(segment_index)->size()) <= segment_end_offset,
80+
!c10::add_overflows(
81+
static_cast<uint64_t>(segments->Get(segment_index)->offset()),
82+
static_cast<uint64_t>(segments->Get(segment_index)->size()),
83+
&seg_end) &&
84+
seg_end <= static_cast<uint64_t>(segment_end_offset),
7985
InvalidExternalData,
8086
"Invalid segment offset %" PRIu64
8187
" is larger than the segment_base_offset + segment_data_size %" PRIu64

runtime/core/array_ref.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
#include <cstdint>
3131

3232
#include <c10/util/irange.h>
33+
#include <c10/util/safe_numerics.h>
3334
#include <executorch/runtime/platform/assert.h>
3435

3536
namespace executorch {
@@ -161,7 +162,8 @@ class ArrayRef final {
161162
/// slice(n, m) - Take M elements of the array starting at element N
162163
ArrayRef<T> slice(size_t N, size_t M) const {
163164
// cant slice longer then the array
164-
ET_CHECK(N + M <= size());
165+
size_t end = 0;
166+
ET_CHECK(!c10::add_overflows(N, M, &end) && end <= size());
165167
return ArrayRef<T>(data() + N, M);
166168
}
167169

runtime/core/hierarchical_allocator.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#pragma once
1010

1111
#include <c10/util/irange.h>
12+
#include <c10/util/safe_numerics.h>
1213

1314
#include <executorch/runtime/core/memory_allocator.h>
1415
#include <executorch/runtime/core/result.h>
@@ -58,8 +59,9 @@ class HierarchicalAllocator final {
5859
size_t offset_bytes,
5960
size_t size_bytes) {
6061
// Check for integer overflow in offset_bytes + size_bytes.
62+
size_t end_bytes = 0;
6163
ET_CHECK_OR_RETURN_ERROR(
62-
size_bytes <= SIZE_MAX - offset_bytes,
64+
!c10::add_overflows(offset_bytes, size_bytes, &end_bytes),
6365
InvalidArgument,
6466
"Integer overflow in offset_bytes (%" ET_PRIsize_t
6567
") + size_bytes (%" ET_PRIsize_t ")",
@@ -73,7 +75,7 @@ class HierarchicalAllocator final {
7375
buffers_.size());
7476
Span<uint8_t> buffer = buffers_[memory_id];
7577
ET_CHECK_OR_RETURN_ERROR(
76-
offset_bytes + size_bytes <= buffer.size(),
78+
end_bytes <= buffer.size(),
7779
MemoryAllocationFailed,
7880
"offset_bytes (%" ET_PRIsize_t ") + size_bytes (%" ET_PRIsize_t
7981
") >= allocator size (%" ET_PRIsize_t

0 commit comments

Comments
 (0)