Skip to content

Commit 29c18de

Browse files
Use uint64_t for FlatTensor segment end
Differential Revision: D106710218 Pull Request resolved: pytorch#19860
1 parent 10e2eec commit 29c18de

1 file changed

Lines changed: 35 additions & 6 deletions

File tree

extension/flat_tensor/flat_tensor_data_map.cpp

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
#include <executorch/runtime/core/span.h>
2222
#include <executorch/runtime/platform/compiler.h>
2323

24+
#include <cinttypes>
25+
2426
using executorch::runtime::Error;
2527
using executorch::runtime::FreeableBuffer;
2628
using executorch::runtime::Result;
@@ -52,7 +54,7 @@ Result<const flat_tensor_flatbuffer::NamedData*> get_named_data(
5254
flatbuffers::Offset<flat_tensor_flatbuffer::NamedData>>* named_data,
5355
const flatbuffers::Vector<
5456
flatbuffers::Offset<flat_tensor_flatbuffer::DataSegment>>* segments,
55-
size_t segment_end_offset) {
57+
uint64_t segment_end_offset) {
5658
// Linear search by name.
5759
if (named_data == nullptr) {
5860
return Error::NotFound;
@@ -81,19 +83,34 @@ Result<const flat_tensor_flatbuffer::NamedData*> get_named_data(
8183
static_cast<uint64_t>(segments->Get(segment_index)->offset()),
8284
static_cast<uint64_t>(segments->Get(segment_index)->size()),
8385
&seg_end) &&
84-
seg_end <= static_cast<uint64_t>(segment_end_offset),
86+
seg_end <= segment_end_offset,
8587
InvalidExternalData,
8688
"Invalid segment offset %" PRIu64
8789
" is larger than the segment_base_offset + segment_data_size %" PRIu64
8890
"; malformed PTD file.",
8991
segments->Get(segment_index)->offset(),
90-
static_cast<uint64_t>(segment_end_offset));
92+
segment_end_offset);
9193
return found;
9294
}
9395
}
9496
return Error::NotFound;
9597
}
9698

99+
Result<uint64_t> get_segment_end_offset(const FlatTensorHeader& header) {
100+
uint64_t segment_end_offset = 0;
101+
ET_CHECK_OR_RETURN_ERROR(
102+
!c10::add_overflows(
103+
header.segment_base_offset,
104+
header.segment_data_size,
105+
&segment_end_offset),
106+
InvalidExternalData,
107+
"segment_base_offset %" PRIu64 " + segment_data_size %" PRIu64
108+
" overflows uint64_t; malformed PTD file.",
109+
header.segment_base_offset,
110+
header.segment_data_size);
111+
return segment_end_offset;
112+
}
113+
97114
Result<const TensorLayout> create_tensor_layout(
98115
const flat_tensor_flatbuffer::TensorLayout* tensor_layout) {
99116
ScalarType scalar_type =
@@ -111,11 +128,15 @@ Result<const TensorLayout> create_tensor_layout(
111128

112129
ET_NODISCARD Result<const TensorLayout> FlatTensorDataMap::get_tensor_layout(
113130
executorch::aten::string_view key) const {
131+
Result<uint64_t> segment_end_offset = get_segment_end_offset(header_);
132+
if (!segment_end_offset.ok()) {
133+
return segment_end_offset.error();
134+
}
114135
Result<const flat_tensor_flatbuffer::NamedData*> named_data = get_named_data(
115136
key,
116137
flat_tensor_->named_data(),
117138
flat_tensor_->segments(),
118-
header_.segment_base_offset + header_.segment_data_size);
139+
segment_end_offset.get());
119140
if (!named_data.ok()) {
120141
return named_data.error();
121142
}
@@ -124,11 +145,15 @@ ET_NODISCARD Result<const TensorLayout> FlatTensorDataMap::get_tensor_layout(
124145

125146
ET_NODISCARD Result<FreeableBuffer> FlatTensorDataMap::get_data(
126147
executorch::aten::string_view key) const {
148+
Result<uint64_t> segment_end_offset = get_segment_end_offset(header_);
149+
if (!segment_end_offset.ok()) {
150+
return segment_end_offset.error();
151+
}
127152
Result<const flat_tensor_flatbuffer::NamedData*> named_data = get_named_data(
128153
key,
129154
flat_tensor_->named_data(),
130155
flat_tensor_->segments(),
131-
header_.segment_base_offset + header_.segment_data_size);
156+
segment_end_offset.get());
132157
if (!named_data.ok()) {
133158
return named_data.error();
134159
}
@@ -148,11 +173,15 @@ ET_NODISCARD Error FlatTensorDataMap::load_data_into(
148173
ET_UNUSED executorch::aten::string_view key,
149174
ET_UNUSED void* buffer,
150175
ET_UNUSED size_t size) const {
176+
Result<uint64_t> segment_end_offset = get_segment_end_offset(header_);
177+
if (!segment_end_offset.ok()) {
178+
return segment_end_offset.error();
179+
}
151180
Result<const flat_tensor_flatbuffer::NamedData*> named_data = get_named_data(
152181
key,
153182
flat_tensor_->named_data(),
154183
flat_tensor_->segments(),
155-
header_.segment_base_offset + header_.segment_data_size);
184+
segment_end_offset.get());
156185
if (!named_data.ok()) {
157186
return named_data.error();
158187
}

0 commit comments

Comments
 (0)