Skip to content

Commit 8948c05

Browse files
authored
[core] Enhance single file storage bounds control (#35155)
### Details:
- Applies overflow and underflow checks in stream manipulations and size/offset calculations.

### Tickets:
- CVS-183610

### AI Assistance:
- AI assistance used: yes (code completions)

---------

Signed-off-by: Tomasz Jankowski <tomasz1.jankowski@intel.com>
1 parent 20361cc commit 8948c05

4 files changed

Lines changed: 196 additions & 72 deletions

File tree

src/inference/dev_api/openvino/runtime/single_file_storage.hpp

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -71,16 +71,16 @@ class SingleFileStorage final : public ICacheManager, public IContextStore {
7171
std::filesystem::path m_file_path;
7272

7373
struct BlobInfo {
74-
std::streampos offset;
75-
std::streamoff size;
74+
uint64_t offset;
75+
uint64_t size;
7676
std::string model_name;
7777
};
7878
std::unordered_map<BlobIdType, BlobInfo> m_blob_index;
7979
std::shared_ptr<wsh::Context> m_shared_context;
8080
bool build_content_index(std::ifstream& stream);
8181

8282
static BlobIdType convert_blob_id(const std::string& blob_id);
83-
void write_blob_entry(std::ofstream& stream, BlobIdType blob_id, StreamWriter& writer);
83+
void write_blob_entry(std::fstream& stream, BlobIdType blob_id, StreamWriter& writer);
8484
bool has_blob_id(BlobIdType blob_id) const;
8585
};
8686
} // namespace ov::runtime

src/inference/src/dev/tlv_format.cpp

Lines changed: 15 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44

55
#include "openvino/runtime/tlv_format.hpp"
66

7+
#include <limits>
78
#include <type_traits>
89
#include <utility>
910

@@ -48,6 +49,19 @@ static bool read_record(std::istream& stream, TLVTraits::TagType& tag, TLVTraits
4849
data.clear();
4950
return true;
5051
}
52+
constexpr auto size_limit =
53+
std::min(static_cast<TLVTraits::LengthType>(std::numeric_limits<std::streamsize>::max()),
54+
static_cast<TLVTraits::LengthType>(std::numeric_limits<size_t>::max()));
55+
if (size > size_limit) {
56+
return false;
57+
}
58+
const auto current_pos = stream.tellg();
59+
const auto stream_end = stream.seekg(0, std::ios::end).tellg();
60+
stream.seekg(current_pos);
61+
const auto remaining_offset = stream.good() ? static_cast<TLVTraits::LengthType>(stream_end - current_pos) : 0u;
62+
if (remaining_offset < size) {
63+
return false;
64+
}
5165
data.resize(size);
5266
stream.read(reinterpret_cast<char*>(data.data()), size);
5367
return stream.good();
@@ -75,7 +89,7 @@ bool scan_tlv_records(std::istream& stream, const TLVValueScanner& scanners) {
7589
return false;
7690
}
7791
stream.read(reinterpret_cast<char*>(&size), sizeof(size));
78-
if (!stream.good()) {
92+
if (!stream.good() || (stream_end - stream.tellg() < static_cast<std::streamoff>(size))) {
7993
return false;
8094
}
8195

@@ -84,9 +98,6 @@ bool scan_tlv_records(std::istream& stream, const TLVValueScanner& scanners) {
8498
return false;
8599
}
86100
} else {
87-
if (stream_end - stream.tellg() < static_cast<std::streamoff>(size) || !stream.good()) {
88-
return false;
89-
}
90101
stream.seekg(size, std::ios::cur);
91102
if (!stream.good()) {
92103
return false;

src/inference/src/single_file_storage.cpp

Lines changed: 60 additions & 60 deletions
Original file line number | Diff line number | Diff line change
@@ -45,12 +45,12 @@ bool read_tlv_string(std::istream& stream, std::string& str) {
4545
TLVTraits::TagType tag;
4646
TLVTraits::LengthType size;
4747
std::vector<char> buffer;
48-
const auto read = read_tlv_record(stream, tag, size, buffer);
49-
if (read) {
50-
OPENVINO_ASSERT(SingleFileStorage::Tag{tag} == SingleFileStorage::Tag::String);
48+
if (read_tlv_record(stream, tag, size, buffer) && SingleFileStorage::Tag{tag} == SingleFileStorage::Tag::String) {
5149
str = std::string{buffer.begin(), buffer.end()};
50+
return true;
51+
} else {
52+
return false;
5253
}
53-
return read;
5454
}
5555

5656
void write_padding(std::ostream& stream, uint64_t alignment) {
@@ -89,115 +89,114 @@ SingleFileStorage::SingleFileStorage(const std::filesystem::path& path)
8989
}
9090

9191
bool SingleFileStorage::build_content_index(std::ifstream& stream) {
92-
const auto current_pos = stream.tellg();
93-
const auto end_pos = stream.seekg(0, std::ios::end).tellg();
94-
stream.seekg(current_pos);
95-
const auto blob_reader = [this, end_pos](std::istream& s, TLVTraits::LengthType size) {
92+
const auto blob_reader = [this](std::istream& s, TLVTraits::LengthType size) {
9693
if (size == 0) {
9794
return true;
9895
}
96+
constexpr auto header_size = sizeof(BlobIdType) + sizeof(PadSizeType);
97+
if (size < header_size) {
98+
return false;
99+
}
99100
BlobIdType id;
100101
PadSizeType padding_size;
101102
s.read(reinterpret_cast<char*>(&id), sizeof(id));
102103
s.read(reinterpret_cast<char*>(&padding_size), sizeof(padding_size));
103-
if (s.tellg() + static_cast<std::streamoff>(padding_size) > end_pos) {
104+
if (!s.good() || padding_size > size - header_size) {
104105
return false;
105106
}
106-
s.seekg(padding_size, std::ios::cur);
107-
if (!s.good()) {
107+
const auto blob_data_pos = s.seekg(padding_size, std::ios::cur).tellg();
108+
if (!s.good() || blob_data_pos < 0) {
108109
return false;
109110
}
110-
const auto blob_data_pos = s.tellg();
111-
const auto blob_data_size =
112-
static_cast<std::streamoff>(size - sizeof(id) - sizeof(padding_size) - padding_size);
113-
m_blob_index[id].offset = blob_data_pos;
114-
m_blob_index[id].size = blob_data_size;
115-
if (blob_data_pos + blob_data_size > end_pos) {
111+
const auto blob_data_size = static_cast<std::streamoff>(size - header_size - padding_size);
112+
s.seekg(blob_data_size, std::ios::cur);
113+
if (!s.good()) {
116114
return false;
117115
}
118-
s.seekg(blob_data_size, std::ios::cur);
119-
return s.good();
116+
m_blob_index[id].offset = static_cast<uint64_t>(blob_data_pos);
117+
m_blob_index[id].size = static_cast<uint64_t>(blob_data_size);
118+
return true;
120119
};
121120
const auto blob_map_reader = [this](std::istream& s, TLVTraits::LengthType size) {
122121
if (size == 0) {
123122
return true;
124123
}
124+
if (size < sizeof(BlobIdType) + sizeof(TLVTraits::TagType) + sizeof(TLVTraits::LengthType)) {
125+
return false;
126+
}
125127
BlobIdType id;
126128
s.read(reinterpret_cast<char*>(&id), sizeof(id));
127129
if (!s.good()) {
128130
return false;
129131
}
130132
if (std::string model_name; read_tlv_string(s, model_name)) {
131133
m_blob_index[id].model_name = std::move(model_name);
134+
return true;
135+
} else {
136+
return false;
132137
}
133-
return s.good();
134138
};
135-
const auto constant_meta_reader = [this, end_pos](std::istream& s, TLVTraits::LengthType size) {
139+
const auto constant_meta_reader = [this](std::istream& s, TLVTraits::LengthType size) {
136140
if (size == 0) {
137141
return true;
138142
}
139-
if (s.tellg() + static_cast<std::streamoff>(size) > end_pos) {
143+
uint64_t source_id;
144+
if (size < sizeof(source_id)) {
140145
return false;
141146
}
142-
uint64_t source_id;
143147
s.read(reinterpret_cast<char*>(&source_id), sizeof(source_id));
144-
auto left_size = size - sizeof(source_id);
145-
146-
while (s.good() && left_size > 0) {
148+
auto remaining_size = size - sizeof(source_id);
149+
while (s.good() && remaining_size > 0) {
147150
uint64_t const_id, const_offset, const_size;
148151
uint8_t const_type;
149152
constexpr auto const_meta_size =
150153
sizeof(const_id) + sizeof(const_offset) + sizeof(const_size) + sizeof(const_type);
151-
if (left_size < const_meta_size) {
154+
if (remaining_size < const_meta_size) {
152155
return false;
153156
}
154-
left_size -= const_meta_size;
157+
remaining_size -= const_meta_size;
155158
s.read(reinterpret_cast<char*>(&const_id), sizeof(const_id));
156159
s.read(reinterpret_cast<char*>(&const_offset), sizeof(const_offset));
157160
s.read(reinterpret_cast<char*>(&const_size), sizeof(const_size));
158161
s.read(reinterpret_cast<char*>(&const_type), sizeof(const_type));
159-
if (!s.good()) {
160-
return false;
162+
if (s.good()) {
163+
m_shared_context->m_weight_registry[source_id][const_id] = {static_cast<size_t>(const_offset),
164+
static_cast<size_t>(const_size),
165+
element::Type_t{const_type}};
161166
}
162-
m_shared_context->m_weight_registry[source_id][const_id] = {static_cast<size_t>(const_offset),
163-
static_cast<size_t>(const_size),
164-
element::Type_t{const_type}};
165167
}
166168
return s.good();
167169
};
168-
const auto constant_source_reader = [this, end_pos](std::istream& s, TLVTraits::LengthType size) {
170+
const auto weight_source_reader = [this](std::istream& s, TLVTraits::LengthType size) {
169171
if (size == 0) {
170172
return true;
171173
}
172-
DataIdType device_id, source_id;
173-
PadSizeType padding_size;
174-
if (size < sizeof(device_id) + sizeof(source_id) + sizeof(padding_size) ||
175-
s.tellg() + static_cast<std::streamoff>(size) > end_pos) {
174+
constexpr auto header_size = sizeof(DataIdType) + sizeof(DataIdType) + sizeof(PadSizeType);
175+
if (size < header_size) {
176176
return false;
177177
}
178+
DataIdType device_id, source_id;
179+
PadSizeType padding_size;
178180
s.read(reinterpret_cast<char*>(&device_id), sizeof(device_id));
179181
s.read(reinterpret_cast<char*>(&source_id), sizeof(source_id));
180182
s.read(reinterpret_cast<char*>(&padding_size), sizeof(padding_size));
181-
if (s.tellg() + static_cast<std::streamoff>(padding_size) > end_pos) {
183+
if (!s.good() || padding_size > size - header_size) {
182184
return false;
183185
}
184186
s.seekg(padding_size, std::ios::cur);
185187
if (!s.good()) {
186188
return false;
187189
}
188-
const auto weight_size = size - sizeof(device_id) - sizeof(source_id) - sizeof(padding_size) - padding_size;
190+
const auto weight_size = size - header_size - padding_size;
189191
m_shared_context->m_cache_sources[source_id] = {};
190-
if (s.tellg() + static_cast<std::streamoff>(weight_size) > end_pos) {
191-
return false;
192-
}
193192
s.seekg(weight_size, std::ios::cur);
194193
return s.good();
195194
};
196195
const TLVValueScanner scanners = {
197196
{static_cast<TLVTraits::TagType>(Tag::Blob), blob_reader},
198197
{static_cast<TLVTraits::TagType>(Tag::BlobMap), blob_map_reader},
199198
{static_cast<TLVTraits::TagType>(Tag::ConstantMeta), constant_meta_reader},
200-
{static_cast<TLVTraits::TagType>(Tag::WeightSource), constant_source_reader},
199+
{static_cast<TLVTraits::TagType>(Tag::WeightSource), weight_source_reader},
201200
};
202201
return scan_tlv_records(stream, scanners);
203202
}
@@ -210,7 +209,7 @@ bool SingleFileStorage::has_blob_id(BlobIdType blob_id) const {
210209
return m_blob_index.find(blob_id) != m_blob_index.end();
211210
}
212211

213-
void SingleFileStorage::write_blob_entry(std::ofstream& stream, BlobIdType blob_id, StreamWriter& writer) {
212+
void SingleFileStorage::write_blob_entry(std::fstream& stream, BlobIdType blob_id, StreamWriter& writer) {
214213
OPENVINO_ASSERT(!has_blob_id(blob_id), "Blob with id ", blob_id, " already exists in cache.");
215214

216215
std::streampos blob_pos;
@@ -220,8 +219,10 @@ void SingleFileStorage::write_blob_entry(std::ofstream& stream, BlobIdType blob_
220219
s.write(reinterpret_cast<const char*>(&blob_id), sizeof(blob_id));
221220
write_padding(s, blob_alignment);
222221
blob_pos = s.tellp();
222+
OPENVINO_ASSERT(blob_pos >= 0, "Invalid blob data position ", blob_pos, " for blob id ", blob_id);
223223
writer(s);
224224
blob_size = s.tellp() - blob_pos;
225+
OPENVINO_ASSERT(blob_size >= 0, "Invalid blob size ", blob_size, " for blob id ", blob_id);
225226
};
226227
write_tlv_record(stream, static_cast<TLVTraits::TagType>(Tag::Blob), blob_writer);
227228

@@ -233,12 +234,13 @@ void SingleFileStorage::write_blob_entry(std::ofstream& stream, BlobIdType blob_
233234
};
234235
write_tlv_record(stream, static_cast<TLVTraits::TagType>(Tag::BlobMap), blob_map_writer);
235236

236-
m_blob_index[blob_id] = {blob_pos, blob_size, std::move(model_name)};
237+
m_blob_index[blob_id] = {static_cast<uint64_t>(blob_pos), static_cast<uint64_t>(blob_size), std::move(model_name)};
237238
}
238239

239240
void SingleFileStorage::write_cache_entry(const std::string& blob_id, StreamWriter writer) {
240241
ScopedLocale plocal_C(LC_ALL, "C");
241-
std::ofstream stream(m_file_path, std::ios::binary | std::ios::in | std::ios::ate);
242+
std::fstream stream(m_file_path, std::ios::binary | std::ios::in | std::ios::out | std::ios::ate);
243+
OPENVINO_ASSERT(stream.good(), "Failed to open cache file ", m_file_path, " for writing blob id ", blob_id);
242244
write_blob_entry(stream, convert_blob_id(blob_id), writer);
243245
}
244246

@@ -250,12 +252,11 @@ void SingleFileStorage::read_cache_entry(const std::string& blob_id, bool enable
250252
if (std::filesystem::exists(m_file_path) && has_blob_id(cid)) {
251253
const auto& [blob_pos, blob_size, model_name] = m_blob_index[cid];
252254
if (enable_mmap) {
253-
// CVS-181859 Extend memory mapping helpers to suport partial file mapping
254255
CompiledBlobVariant compiled_blob{std::in_place_index<0>,
255-
ov::read_tensor_data(m_file_path,
256-
element::u8,
257-
{static_cast<PartialShape::value_type>(blob_size)},
258-
blob_pos)};
256+
read_tensor_data(m_file_path,
257+
element::u8,
258+
{static_cast<PartialShape::value_type>(blob_size)},
259+
blob_pos)};
259260
reader(compiled_blob);
260261
} else {
261262
std::ifstream stream(m_file_path, std::ios::binary);
@@ -312,19 +313,18 @@ void SingleFileStorage::write_context(const weight_sharing::Context& context) {
312313
}
313314
}
314315
for (const auto& cache_registry : delta_cache_sources) {
315-
const auto weight_source_writer = [&](std::ostream& s) {
316-
const auto& [source_id, weight_buffer] = cache_registry;
317-
if (auto buf = weight_buffer.m_weights.lock()) {
316+
const auto& source_id = cache_registry.first;
317+
const auto& weight_buffer = cache_registry.second;
318+
if (auto weights = weight_buffer.m_weights.lock()) {
319+
write_tlv_record(stream, static_cast<TLVTraits::TagType>(Tag::WeightSource), [&](std::ostream& s) {
318320
const auto device_id = static_cast<uint64_t>(std::strtoul(weight_buffer.m_device.c_str(), nullptr, 10));
319321
s.write(reinterpret_cast<const char*>(&device_id), sizeof(device_id));
320322
s.write(reinterpret_cast<const char*>(&source_id), sizeof(source_id));
321323
write_padding(s, blob_alignment);
322-
s.write(reinterpret_cast<const char*>(buf->get_ptr()), buf->size());
323-
}
324-
324+
s.write(weights->get_ptr<char>(), weights->size());
325+
});
325326
m_shared_context->m_cache_sources[source_id] = weight_buffer;
326-
};
327-
write_tlv_record(stream, static_cast<TLVTraits::TagType>(Tag::WeightSource), weight_source_writer);
327+
}
328328
}
329329

330330
for (const auto& [source_id, buffer] : context.m_runtime_sources) {

0 commit comments

Comments
 (0)