2121#include < executorch/runtime/core/span.h>
2222#include < executorch/runtime/platform/compiler.h>
2323
24+ #include < cinttypes>
25+
2426using executorch::runtime::Error;
2527using executorch::runtime::FreeableBuffer;
2628using executorch::runtime::Result;
@@ -52,7 +54,7 @@ Result<const flat_tensor_flatbuffer::NamedData*> get_named_data(
5254 flatbuffers::Offset<flat_tensor_flatbuffer::NamedData>>* named_data,
5355 const flatbuffers::Vector<
5456 flatbuffers::Offset<flat_tensor_flatbuffer::DataSegment>>* segments,
55- size_t segment_end_offset) {
57+ uint64_t segment_end_offset) {
5658 // Linear search by name.
5759 if (named_data == nullptr ) {
5860 return Error::NotFound;
@@ -81,19 +83,34 @@ Result<const flat_tensor_flatbuffer::NamedData*> get_named_data(
8183 static_cast <uint64_t >(segments->Get (segment_index)->offset ()),
8284 static_cast <uint64_t >(segments->Get (segment_index)->size ()),
8385 &seg_end) &&
84- seg_end <= static_cast < uint64_t >( segment_end_offset) ,
86+ seg_end <= segment_end_offset,
8587 InvalidExternalData,
8688 " Invalid segment offset %" PRIu64
8789 " is larger than the segment_base_offset + segment_data_size %" PRIu64
8890 " ; malformed PTD file." ,
8991 segments->Get (segment_index)->offset (),
90- static_cast < uint64_t >( segment_end_offset) );
92+ segment_end_offset);
9193 return found;
9294 }
9395 }
9496 return Error::NotFound;
9597}
9698
99+ Result<uint64_t > get_segment_end_offset (const FlatTensorHeader& header) {
100+ uint64_t segment_end_offset = 0 ;
101+ ET_CHECK_OR_RETURN_ERROR (
102+ !c10::add_overflows (
103+ header.segment_base_offset ,
104+ header.segment_data_size ,
105+ &segment_end_offset),
106+ InvalidExternalData,
107+ " segment_base_offset %" PRIu64 " + segment_data_size %" PRIu64
108+ " overflows uint64_t; malformed PTD file." ,
109+ header.segment_base_offset ,
110+ header.segment_data_size );
111+ return segment_end_offset;
112+ }
113+
97114Result<const TensorLayout> create_tensor_layout (
98115 const flat_tensor_flatbuffer::TensorLayout* tensor_layout) {
99116 ScalarType scalar_type =
@@ -111,11 +128,15 @@ Result<const TensorLayout> create_tensor_layout(
111128
112129ET_NODISCARD Result<const TensorLayout> FlatTensorDataMap::get_tensor_layout (
113130 executorch::aten::string_view key) const {
131+ Result<uint64_t > segment_end_offset = get_segment_end_offset (header_);
132+ if (!segment_end_offset.ok ()) {
133+ return segment_end_offset.error ();
134+ }
114135 Result<const flat_tensor_flatbuffer::NamedData*> named_data = get_named_data (
115136 key,
116137 flat_tensor_->named_data (),
117138 flat_tensor_->segments (),
118- header_. segment_base_offset + header_. segment_data_size );
139+ segment_end_offset. get () );
119140 if (!named_data.ok ()) {
120141 return named_data.error ();
121142 }
@@ -124,11 +145,15 @@ ET_NODISCARD Result<const TensorLayout> FlatTensorDataMap::get_tensor_layout(
124145
125146ET_NODISCARD Result<FreeableBuffer> FlatTensorDataMap::get_data (
126147 executorch::aten::string_view key) const {
148+ Result<uint64_t > segment_end_offset = get_segment_end_offset (header_);
149+ if (!segment_end_offset.ok ()) {
150+ return segment_end_offset.error ();
151+ }
127152 Result<const flat_tensor_flatbuffer::NamedData*> named_data = get_named_data (
128153 key,
129154 flat_tensor_->named_data (),
130155 flat_tensor_->segments (),
131- header_. segment_base_offset + header_. segment_data_size );
156+ segment_end_offset. get () );
132157 if (!named_data.ok ()) {
133158 return named_data.error ();
134159 }
@@ -148,11 +173,15 @@ ET_NODISCARD Error FlatTensorDataMap::load_data_into(
148173 ET_UNUSED executorch::aten::string_view key,
149174 ET_UNUSED void * buffer,
150175 ET_UNUSED size_t size) const {
176+ Result<uint64_t > segment_end_offset = get_segment_end_offset (header_);
177+ if (!segment_end_offset.ok ()) {
178+ return segment_end_offset.error ();
179+ }
151180 Result<const flat_tensor_flatbuffer::NamedData*> named_data = get_named_data (
152181 key,
153182 flat_tensor_->named_data (),
154183 flat_tensor_->segments (),
155- header_. segment_base_offset + header_. segment_data_size );
184+ segment_end_offset. get () );
156185 if (!named_data.ok ()) {
157186 return named_data.error ();
158187 }
0 commit comments