@@ -45,12 +45,12 @@ bool read_tlv_string(std::istream& stream, std::string& str) {
4545 TLVTraits::TagType tag;
4646 TLVTraits::LengthType size;
4747 std::vector<char > buffer;
48- const auto read = read_tlv_record (stream, tag, size, buffer);
49- if (read) {
50- OPENVINO_ASSERT (SingleFileStorage::Tag{tag} == SingleFileStorage::Tag::String);
48+ if (read_tlv_record (stream, tag, size, buffer) && SingleFileStorage::Tag{tag} == SingleFileStorage::Tag::String) {
5149 str = std::string{buffer.begin (), buffer.end ()};
50+ return true ;
51+ } else {
52+ return false ;
5253 }
53- return read;
5454}
5555
5656void write_padding (std::ostream& stream, uint64_t alignment) {
@@ -89,115 +89,114 @@ SingleFileStorage::SingleFileStorage(const std::filesystem::path& path)
8989}
9090
9191bool SingleFileStorage::build_content_index (std::ifstream& stream) {
92- const auto current_pos = stream.tellg ();
93- const auto end_pos = stream.seekg (0 , std::ios::end).tellg ();
94- stream.seekg (current_pos);
95- const auto blob_reader = [this , end_pos](std::istream& s, TLVTraits::LengthType size) {
92+ const auto blob_reader = [this ](std::istream& s, TLVTraits::LengthType size) {
9693 if (size == 0 ) {
9794 return true ;
9895 }
96+ constexpr auto header_size = sizeof (BlobIdType) + sizeof (PadSizeType);
97+ if (size < header_size) {
98+ return false ;
99+ }
99100 BlobIdType id;
100101 PadSizeType padding_size;
101102 s.read (reinterpret_cast <char *>(&id), sizeof (id));
102103 s.read (reinterpret_cast <char *>(&padding_size), sizeof (padding_size));
103- if (s. tellg () + static_cast <std::streamoff>( padding_size) > end_pos ) {
104+ if (!s. good () || padding_size > size - header_size ) {
104105 return false ;
105106 }
106- s.seekg (padding_size, std::ios::cur);
107- if (!s.good ()) {
107+ const auto blob_data_pos = s.seekg (padding_size, std::ios::cur). tellg ( );
108+ if (!s.good () || blob_data_pos < 0 ) {
108109 return false ;
109110 }
110- const auto blob_data_pos = s.tellg ();
111- const auto blob_data_size =
112- static_cast <std::streamoff>(size - sizeof (id) - sizeof (padding_size) - padding_size);
113- m_blob_index[id].offset = blob_data_pos;
114- m_blob_index[id].size = blob_data_size;
115- if (blob_data_pos + blob_data_size > end_pos) {
111+ const auto blob_data_size = static_cast <std::streamoff>(size - header_size - padding_size);
112+ s.seekg (blob_data_size, std::ios::cur);
113+ if (!s.good ()) {
116114 return false ;
117115 }
118- s.seekg (blob_data_size, std::ios::cur);
119- return s.good ();
116+ m_blob_index[id].offset = static_cast <uint64_t >(blob_data_pos);
117+ m_blob_index[id].size = static_cast <uint64_t >(blob_data_size);
118+ return true ;
120119 };
121120 const auto blob_map_reader = [this ](std::istream& s, TLVTraits::LengthType size) {
122121 if (size == 0 ) {
123122 return true ;
124123 }
124+ if (size < sizeof (BlobIdType) + sizeof (TLVTraits::TagType) + sizeof (TLVTraits::LengthType)) {
125+ return false ;
126+ }
125127 BlobIdType id;
126128 s.read (reinterpret_cast <char *>(&id), sizeof (id));
127129 if (!s.good ()) {
128130 return false ;
129131 }
130132 if (std::string model_name; read_tlv_string (s, model_name)) {
131133 m_blob_index[id].model_name = std::move (model_name);
134+ return true ;
135+ } else {
136+ return false ;
132137 }
133- return s.good ();
134138 };
135- const auto constant_meta_reader = [this , end_pos ](std::istream& s, TLVTraits::LengthType size) {
139+ const auto constant_meta_reader = [this ](std::istream& s, TLVTraits::LengthType size) {
136140 if (size == 0 ) {
137141 return true ;
138142 }
139- if (s.tellg () + static_cast <std::streamoff>(size) > end_pos) {
143+ uint64_t source_id;
144+ if (size < sizeof (source_id)) {
140145 return false ;
141146 }
142- uint64_t source_id;
143147 s.read (reinterpret_cast <char *>(&source_id), sizeof (source_id));
144- auto left_size = size - sizeof (source_id);
145-
146- while (s.good () && left_size > 0 ) {
148+ auto remaining_size = size - sizeof (source_id);
149+ while (s.good () && remaining_size > 0 ) {
147150 uint64_t const_id, const_offset, const_size;
148151 uint8_t const_type;
149152 constexpr auto const_meta_size =
150153 sizeof (const_id) + sizeof (const_offset) + sizeof (const_size) + sizeof (const_type);
151- if (left_size < const_meta_size) {
154+ if (remaining_size < const_meta_size) {
152155 return false ;
153156 }
154- left_size -= const_meta_size;
157+ remaining_size -= const_meta_size;
155158 s.read (reinterpret_cast <char *>(&const_id), sizeof (const_id));
156159 s.read (reinterpret_cast <char *>(&const_offset), sizeof (const_offset));
157160 s.read (reinterpret_cast <char *>(&const_size), sizeof (const_size));
158161 s.read (reinterpret_cast <char *>(&const_type), sizeof (const_type));
159- if (!s.good ()) {
160- return false ;
162+ if (s.good ()) {
163+ m_shared_context->m_weight_registry [source_id][const_id] = {static_cast <size_t >(const_offset),
164+ static_cast <size_t >(const_size),
165+ element::Type_t{const_type}};
161166 }
162- m_shared_context->m_weight_registry [source_id][const_id] = {static_cast <size_t >(const_offset),
163- static_cast <size_t >(const_size),
164- element::Type_t{const_type}};
165167 }
166168 return s.good ();
167169 };
168- const auto constant_source_reader = [this , end_pos ](std::istream& s, TLVTraits::LengthType size) {
170+ const auto weight_source_reader = [this ](std::istream& s, TLVTraits::LengthType size) {
169171 if (size == 0 ) {
170172 return true ;
171173 }
172- DataIdType device_id, source_id;
173- PadSizeType padding_size;
174- if (size < sizeof (device_id) + sizeof (source_id) + sizeof (padding_size) ||
175- s.tellg () + static_cast <std::streamoff>(size) > end_pos) {
174+ constexpr auto header_size = sizeof (DataIdType) + sizeof (DataIdType) + sizeof (PadSizeType);
175+ if (size < header_size) {
176176 return false ;
177177 }
178+ DataIdType device_id, source_id;
179+ PadSizeType padding_size;
178180 s.read (reinterpret_cast <char *>(&device_id), sizeof (device_id));
179181 s.read (reinterpret_cast <char *>(&source_id), sizeof (source_id));
180182 s.read (reinterpret_cast <char *>(&padding_size), sizeof (padding_size));
181- if (s. tellg () + static_cast <std::streamoff>( padding_size) > end_pos ) {
183+ if (!s. good () || padding_size > size - header_size ) {
182184 return false ;
183185 }
184186 s.seekg (padding_size, std::ios::cur);
185187 if (!s.good ()) {
186188 return false ;
187189 }
188- const auto weight_size = size - sizeof (device_id) - sizeof (source_id) - sizeof (padding_size) - padding_size;
190+ const auto weight_size = size - header_size - padding_size;
189191 m_shared_context->m_cache_sources [source_id] = {};
190- if (s.tellg () + static_cast <std::streamoff>(weight_size) > end_pos) {
191- return false ;
192- }
193192 s.seekg (weight_size, std::ios::cur);
194193 return s.good ();
195194 };
196195 const TLVValueScanner scanners = {
197196 {static_cast <TLVTraits::TagType>(Tag::Blob), blob_reader},
198197 {static_cast <TLVTraits::TagType>(Tag::BlobMap), blob_map_reader},
199198 {static_cast <TLVTraits::TagType>(Tag::ConstantMeta), constant_meta_reader},
200- {static_cast <TLVTraits::TagType>(Tag::WeightSource), constant_source_reader },
199+ {static_cast <TLVTraits::TagType>(Tag::WeightSource), weight_source_reader },
201200 };
202201 return scan_tlv_records (stream, scanners);
203202}
@@ -210,7 +209,7 @@ bool SingleFileStorage::has_blob_id(BlobIdType blob_id) const {
210209 return m_blob_index.find (blob_id) != m_blob_index.end ();
211210}
212211
213- void SingleFileStorage::write_blob_entry (std::ofstream & stream, BlobIdType blob_id, StreamWriter& writer) {
212+ void SingleFileStorage::write_blob_entry (std::fstream & stream, BlobIdType blob_id, StreamWriter& writer) {
214213 OPENVINO_ASSERT (!has_blob_id (blob_id), " Blob with id " , blob_id, " already exists in cache." );
215214
216215 std::streampos blob_pos;
@@ -220,8 +219,10 @@ void SingleFileStorage::write_blob_entry(std::ofstream& stream, BlobIdType blob_
220219 s.write (reinterpret_cast <const char *>(&blob_id), sizeof (blob_id));
221220 write_padding (s, blob_alignment);
222221 blob_pos = s.tellp ();
222+ OPENVINO_ASSERT (blob_pos >= 0 , " Invalid blob data position " , blob_pos, " for blob id " , blob_id);
223223 writer (s);
224224 blob_size = s.tellp () - blob_pos;
225+ OPENVINO_ASSERT (blob_size >= 0 , " Invalid blob size " , blob_size, " for blob id " , blob_id);
225226 };
226227 write_tlv_record (stream, static_cast <TLVTraits::TagType>(Tag::Blob), blob_writer);
227228
@@ -233,12 +234,13 @@ void SingleFileStorage::write_blob_entry(std::ofstream& stream, BlobIdType blob_
233234 };
234235 write_tlv_record (stream, static_cast <TLVTraits::TagType>(Tag::BlobMap), blob_map_writer);
235236
236- m_blob_index[blob_id] = {blob_pos, blob_size, std::move (model_name)};
237+ m_blob_index[blob_id] = {static_cast < uint64_t >( blob_pos), static_cast < uint64_t >( blob_size) , std::move (model_name)};
237238}
238239
239240void SingleFileStorage::write_cache_entry (const std::string& blob_id, StreamWriter writer) {
240241 ScopedLocale plocal_C (LC_ALL, " C" );
241- std::ofstream stream (m_file_path, std::ios::binary | std::ios::in | std::ios::ate);
242+ std::fstream stream (m_file_path, std::ios::binary | std::ios::in | std::ios::out | std::ios::ate);
243+ OPENVINO_ASSERT (stream.good (), " Failed to open cache file " , m_file_path, " for writing blob id " , blob_id);
242244 write_blob_entry (stream, convert_blob_id (blob_id), writer);
243245}
244246
@@ -250,12 +252,11 @@ void SingleFileStorage::read_cache_entry(const std::string& blob_id, bool enable
250252 if (std::filesystem::exists (m_file_path) && has_blob_id (cid)) {
251253 const auto & [blob_pos, blob_size, model_name] = m_blob_index[cid];
252254 if (enable_mmap) {
253- // CVS-181859 Extend memory mapping helpers to suport partial file mapping
254255 CompiledBlobVariant compiled_blob{std::in_place_index<0 >,
255- ov:: read_tensor_data (m_file_path,
256- element::u8 ,
257- {static_cast <PartialShape::value_type>(blob_size)},
258- blob_pos)};
256+ read_tensor_data (m_file_path,
257+ element::u8 ,
258+ {static_cast <PartialShape::value_type>(blob_size)},
259+ blob_pos)};
259260 reader (compiled_blob);
260261 } else {
261262 std::ifstream stream (m_file_path, std::ios::binary);
@@ -312,19 +313,18 @@ void SingleFileStorage::write_context(const weight_sharing::Context& context) {
312313 }
313314 }
314315 for (const auto & cache_registry : delta_cache_sources) {
315- const auto weight_source_writer = [&](std::ostream& s) {
316- const auto & [source_id, weight_buffer] = cache_registry;
317- if (auto buf = weight_buffer.m_weights .lock ()) {
316+ const auto & source_id = cache_registry.first ;
317+ const auto & weight_buffer = cache_registry.second ;
318+ if (auto weights = weight_buffer.m_weights .lock ()) {
319+ write_tlv_record (stream, static_cast <TLVTraits::TagType>(Tag::WeightSource), [&](std::ostream& s) {
318320 const auto device_id = static_cast <uint64_t >(std::strtoul (weight_buffer.m_device .c_str (), nullptr , 10 ));
319321 s.write (reinterpret_cast <const char *>(&device_id), sizeof (device_id));
320322 s.write (reinterpret_cast <const char *>(&source_id), sizeof (source_id));
321323 write_padding (s, blob_alignment);
322- s.write (reinterpret_cast <const char *>(buf->get_ptr ()), buf->size ());
323- }
324-
324+ s.write (weights->get_ptr <char >(), weights->size ());
325+ });
325326 m_shared_context->m_cache_sources [source_id] = weight_buffer;
326- };
327- write_tlv_record (stream, static_cast <TLVTraits::TagType>(Tag::WeightSource), weight_source_writer);
327+ }
328328 }
329329
330330 for (const auto & [source_id, buffer] : context.m_runtime_sources ) {
0 commit comments