diff --git a/src/inference/dev_api/openvino/runtime/single_file_storage.hpp b/src/inference/dev_api/openvino/runtime/single_file_storage.hpp index d7be109b2dd7ae..d79f2cb7a95df0 100644 --- a/src/inference/dev_api/openvino/runtime/single_file_storage.hpp +++ b/src/inference/dev_api/openvino/runtime/single_file_storage.hpp @@ -71,8 +71,8 @@ class SingleFileStorage final : public ICacheManager, public IContextStore { std::filesystem::path m_file_path; struct BlobInfo { - std::streampos offset; - std::streamoff size; + uint64_t offset; + uint64_t size; std::string model_name; }; std::unordered_map m_blob_index; @@ -80,7 +80,7 @@ class SingleFileStorage final : public ICacheManager, public IContextStore { bool build_content_index(std::ifstream& stream); static BlobIdType convert_blob_id(const std::string& blob_id); - void write_blob_entry(std::ofstream& stream, BlobIdType blob_id, StreamWriter& writer); + void write_blob_entry(std::fstream& stream, BlobIdType blob_id, StreamWriter& writer); bool has_blob_id(BlobIdType blob_id) const; }; } // namespace ov::runtime diff --git a/src/inference/src/dev/tlv_format.cpp b/src/inference/src/dev/tlv_format.cpp index 96a449ff8cc716..b06e4e10f0cda6 100644 --- a/src/inference/src/dev/tlv_format.cpp +++ b/src/inference/src/dev/tlv_format.cpp @@ -4,6 +4,7 @@ #include "openvino/runtime/tlv_format.hpp" +#include #include #include @@ -48,6 +49,19 @@ static bool read_record(std::istream& stream, TLVTraits::TagType& tag, TLVTraits data.clear(); return true; } + constexpr auto size_limit = + std::min(static_cast(std::numeric_limits::max()), + static_cast(std::numeric_limits::max())); + if (size > size_limit) { + return false; + } + const auto current_pos = stream.tellg(); + const auto stream_end = stream.seekg(0, std::ios::end).tellg(); + stream.seekg(current_pos); + const auto remaining_offset = stream.good() ? static_cast(stream_end - current_pos) : 0u; + if (remaining_offset < size) { + return false; + } data.resize(size); stream.read(reinterpret_cast(data.data()), size); return stream.good(); @@ -75,7 +89,7 @@ bool scan_tlv_records(std::istream& stream, const TLVValueScanner& scanners) { return false; } stream.read(reinterpret_cast(&size), sizeof(size)); - if (!stream.good()) { + if (!stream.good() || (stream_end - stream.tellg() < static_cast(size))) { return false; } @@ -84,9 +98,6 @@ bool scan_tlv_records(std::istream& stream, const TLVValueScanner& scanners) { return false; } } else { - if (stream_end - stream.tellg() < static_cast(size) || !stream.good()) { - return false; - } stream.seekg(size, std::ios::cur); if (!stream.good()) { return false; diff --git a/src/inference/src/single_file_storage.cpp b/src/inference/src/single_file_storage.cpp index 648c6e646e6ab2..6c846a055d20cb 100644 --- a/src/inference/src/single_file_storage.cpp +++ b/src/inference/src/single_file_storage.cpp @@ -45,12 +45,12 @@ bool read_tlv_string(std::istream& stream, std::string& str) { TLVTraits::TagType tag; TLVTraits::LengthType size; std::vector buffer; - const auto read = read_tlv_record(stream, tag, size, buffer); - if (read) { - OPENVINO_ASSERT(SingleFileStorage::Tag{tag} == SingleFileStorage::Tag::String); + if (read_tlv_record(stream, tag, size, buffer) && SingleFileStorage::Tag{tag} == SingleFileStorage::Tag::String) { str = std::string{buffer.begin(), buffer.end()}; + return true; + } else { + return false; } - return read; } void write_padding(std::ostream& stream, uint64_t alignment) { @@ -89,39 +89,41 @@ SingleFileStorage::SingleFileStorage(const std::filesystem::path& path) } bool SingleFileStorage::build_content_index(std::ifstream& stream) { - const auto current_pos = stream.tellg(); - const auto end_pos = stream.seekg(0, std::ios::end).tellg(); - stream.seekg(current_pos); - const auto blob_reader = [this, end_pos](std::istream& s, TLVTraits::LengthType size) { + const auto blob_reader = [this](std::istream& s, TLVTraits::LengthType size) { if (size == 0) { return true; } + constexpr auto header_size = sizeof(BlobIdType) + sizeof(PadSizeType); + if (size < header_size) { + return false; + } BlobIdType id; PadSizeType padding_size; s.read(reinterpret_cast(&id), sizeof(id)); s.read(reinterpret_cast(&padding_size), sizeof(padding_size)); - if (s.tellg() + static_cast(padding_size) > end_pos) { + if (!s.good() || padding_size > size - header_size) { return false; } - s.seekg(padding_size, std::ios::cur); - if (!s.good()) { + const auto blob_data_pos = s.seekg(padding_size, std::ios::cur).tellg(); + if (!s.good() || blob_data_pos < 0) { return false; } - const auto blob_data_pos = s.tellg(); - const auto blob_data_size = - static_cast(size - sizeof(id) - sizeof(padding_size) - padding_size); - m_blob_index[id].offset = blob_data_pos; - m_blob_index[id].size = blob_data_size; - if (blob_data_pos + blob_data_size > end_pos) { + const auto blob_data_size = static_cast(size - header_size - padding_size); + s.seekg(blob_data_size, std::ios::cur); + if (!s.good()) { return false; } - s.seekg(blob_data_size, std::ios::cur); - return s.good(); + m_blob_index[id].offset = static_cast(blob_data_pos); + m_blob_index[id].size = static_cast(blob_data_size); + return true; }; const auto blob_map_reader = [this](std::istream& s, TLVTraits::LengthType size) { if (size == 0) { return true; } + if (size < sizeof(BlobIdType) + sizeof(TLVTraits::TagType) + sizeof(TLVTraits::LengthType)) { + return false; + } BlobIdType id; s.read(reinterpret_cast(&id), sizeof(id)); if (!s.good()) { @@ -129,67 +131,64 @@ bool SingleFileStorage::build_content_index(std::ifstream& stream) { } if (std::string model_name; read_tlv_string(s, model_name)) { m_blob_index[id].model_name = std::move(model_name); + return true; + } else { + return false; } - return s.good(); }; - const auto constant_meta_reader = [this, end_pos](std::istream& s, TLVTraits::LengthType size) { + const auto constant_meta_reader = [this](std::istream& s, TLVTraits::LengthType size) { if (size == 0) { return true; } - if (s.tellg() + static_cast(size) > end_pos) { + uint64_t source_id; + if (size < sizeof(source_id)) { return false; } - uint64_t source_id; s.read(reinterpret_cast(&source_id), sizeof(source_id)); - auto left_size = size - sizeof(source_id); - - while (s.good() && left_size > 0) { + auto remaining_size = size - sizeof(source_id); + while (s.good() && remaining_size > 0) { uint64_t const_id, const_offset, const_size; uint8_t const_type; constexpr auto const_meta_size = sizeof(const_id) + sizeof(const_offset) + sizeof(const_size) + sizeof(const_type); - if (left_size < const_meta_size) { + if (remaining_size < const_meta_size) { return false; } - left_size -= const_meta_size; + remaining_size -= const_meta_size; s.read(reinterpret_cast(&const_id), sizeof(const_id)); s.read(reinterpret_cast(&const_offset), sizeof(const_offset)); s.read(reinterpret_cast(&const_size), sizeof(const_size)); s.read(reinterpret_cast(&const_type), sizeof(const_type)); - if (!s.good()) { - return false; + if (s.good()) { + m_shared_context->m_weight_registry[source_id][const_id] = {static_cast(const_offset), + static_cast(const_size), + element::Type_t{const_type}}; } - m_shared_context->m_weight_registry[source_id][const_id] = {static_cast(const_offset), - static_cast(const_size), - element::Type_t{const_type}}; } return s.good(); }; - const auto constant_source_reader = [this, end_pos](std::istream& s, TLVTraits::LengthType size) { + const auto weight_source_reader = [this](std::istream& s, TLVTraits::LengthType size) { if (size == 0) { return true; } - DataIdType device_id, source_id; - PadSizeType padding_size; - if (size < sizeof(device_id) + sizeof(source_id) + sizeof(padding_size) || - s.tellg() + static_cast(size) > end_pos) { + constexpr auto header_size = sizeof(DataIdType) + sizeof(DataIdType) + sizeof(PadSizeType); + if (size < header_size) { return false; } + DataIdType device_id, source_id; + PadSizeType padding_size; s.read(reinterpret_cast(&device_id), sizeof(device_id)); s.read(reinterpret_cast(&source_id), sizeof(source_id)); s.read(reinterpret_cast(&padding_size), sizeof(padding_size)); - if (s.tellg() + static_cast(padding_size) > end_pos) { + if (!s.good() || padding_size > size - header_size) { return false; } s.seekg(padding_size, std::ios::cur); if (!s.good()) { return false; } - const auto weight_size = size - sizeof(device_id) - sizeof(source_id) - sizeof(padding_size) - padding_size; + const auto weight_size = size - header_size - padding_size; m_shared_context->m_cache_sources[source_id] = {}; - if (s.tellg() + static_cast(weight_size) > end_pos) { - return false; - } s.seekg(weight_size, std::ios::cur); return s.good(); }; @@ -197,7 +196,7 @@ bool SingleFileStorage::build_content_index(std::ifstream& stream) { {static_cast(Tag::Blob), blob_reader}, {static_cast(Tag::BlobMap), blob_map_reader}, {static_cast(Tag::ConstantMeta), constant_meta_reader}, - {static_cast(Tag::WeightSource), constant_source_reader}, + {static_cast(Tag::WeightSource), weight_source_reader}, }; return scan_tlv_records(stream, scanners); } @@ -210,7 +209,7 @@ bool SingleFileStorage::has_blob_id(BlobIdType blob_id) const { return m_blob_index.find(blob_id) != m_blob_index.end(); } -void SingleFileStorage::write_blob_entry(std::ofstream& stream, BlobIdType blob_id, StreamWriter& writer) { +void SingleFileStorage::write_blob_entry(std::fstream& stream, BlobIdType blob_id, StreamWriter& writer) { OPENVINO_ASSERT(!has_blob_id(blob_id), "Blob with id ", blob_id, " already exists in cache."); std::streampos blob_pos; @@ -220,8 +219,10 @@ void SingleFileStorage::write_blob_entry(std::ofstream& stream, BlobIdType blob_ s.write(reinterpret_cast(&blob_id), sizeof(blob_id)); write_padding(s, blob_alignment); blob_pos = s.tellp(); + OPENVINO_ASSERT(blob_pos >= 0, "Invalid blob data position ", blob_pos, " for blob id ", blob_id); writer(s); blob_size = s.tellp() - blob_pos; + OPENVINO_ASSERT(blob_size >= 0, "Invalid blob size ", blob_size, " for blob id ", blob_id); }; write_tlv_record(stream, static_cast(Tag::Blob), blob_writer); @@ -233,12 +234,13 @@ void SingleFileStorage::write_blob_entry(std::ofstream& stream, BlobIdType blob_ }; write_tlv_record(stream, static_cast(Tag::BlobMap), blob_map_writer); - m_blob_index[blob_id] = {blob_pos, blob_size, std::move(model_name)}; + m_blob_index[blob_id] = {static_cast(blob_pos), static_cast(blob_size), std::move(model_name)}; } void SingleFileStorage::write_cache_entry(const std::string& blob_id, StreamWriter writer) { ScopedLocale plocal_C(LC_ALL, "C"); - std::ofstream stream(m_file_path, std::ios::binary | std::ios::in | std::ios::ate); + std::fstream stream(m_file_path, std::ios::binary | std::ios::in | std::ios::out | std::ios::ate); + OPENVINO_ASSERT(stream.good(), "Failed to open cache file ", m_file_path, " for writing blob id ", blob_id); write_blob_entry(stream, convert_blob_id(blob_id), writer); } @@ -250,12 +252,11 @@ void SingleFileStorage::read_cache_entry(const std::string& blob_id, bool enable if (std::filesystem::exists(m_file_path) && has_blob_id(cid)) { const auto& [blob_pos, blob_size, model_name] = m_blob_index[cid]; if (enable_mmap) { - // CVS-181859 Extend memory mapping helpers to suport partial file mapping CompiledBlobVariant compiled_blob{std::in_place_index<0>, - ov::read_tensor_data(m_file_path, - element::u8, - {static_cast(blob_size)}, - blob_pos)}; + read_tensor_data(m_file_path, + element::u8, + {static_cast(blob_size)}, + blob_pos)}; reader(compiled_blob); } else { std::ifstream stream(m_file_path, std::ios::binary); @@ -312,19 +313,18 @@ void SingleFileStorage::write_context(const weight_sharing::Context& context) { } } for (const auto& cache_registry : delta_cache_sources) { - const auto weight_source_writer = [&](std::ostream& s) { - const auto& [source_id, weight_buffer] = cache_registry; - if (auto buf = weight_buffer.m_weights.lock()) { + const auto& source_id = cache_registry.first; + const auto& weight_buffer = cache_registry.second; + if (auto weights = weight_buffer.m_weights.lock()) { + write_tlv_record(stream, static_cast(Tag::WeightSource), [&](std::ostream& s) { const auto device_id = static_cast(std::strtoul(weight_buffer.m_device.c_str(), nullptr, 10)); s.write(reinterpret_cast(&device_id), sizeof(device_id)); s.write(reinterpret_cast(&source_id), sizeof(source_id)); write_padding(s, blob_alignment); - s.write(reinterpret_cast(buf->get_ptr()), buf->size()); - } - + s.write(weights->get_ptr(), weights->size()); + }); m_shared_context->m_cache_sources[source_id] = weight_buffer; - }; - write_tlv_record(stream, static_cast(Tag::WeightSource), weight_source_writer); + } } for (const auto& [source_id, buffer] : context.m_runtime_sources) { diff --git a/src/inference/tests/unit/single_file_storage.cpp b/src/inference/tests/unit/single_file_storage.cpp index 2b9457516e547e..955990cf49ccc5 100644 --- a/src/inference/tests/unit/single_file_storage.cpp +++ b/src/inference/tests/unit/single_file_storage.cpp @@ -10,11 +10,13 @@ #include "common_test_utils/file_utils.hpp" #include "common_test_utils/test_assertions.hpp" #include "openvino/runtime/aligned_buffer.hpp" +#include "openvino/runtime/tlv_format.hpp" #include "openvino/util/mmap_object.hpp" namespace ov::test { using runtime::SingleFileStorage; +using runtime::TLVTraits; namespace { constexpr uint64_t version_size() { @@ -22,7 +24,12 @@ constexpr uint64_t version_size() { } } // namespace -class SingleFileStorageTest : public ::testing::Test { +struct SingleFileStorageTestParam { + SingleFileStorage::Tag tag; + runtime::TLVValueWriter writer; +}; + +class SingleFileStorageTest : public ::testing::TestWithParam { protected: std::filesystem::path m_file_path; std::unique_ptr m_storage; @@ -31,7 +38,7 @@ class SingleFileStorageTest : public ::testing::Test { m_file_path = ov::test::utils::generateTestFilePrefix() + ".bin"; m_storage = std::make_unique(m_file_path); m_storage->initialize(); - ASSERT_TRUE(std::filesystem::exists(m_file_path)); + ASSERT_TRUE(util::file_exists(m_file_path)); } void TearDown() override { @@ -88,7 +95,6 @@ TEST_F(SingleFileStorageTest, WriteReadCacheEntry) { storage.read_cache_entry(blob_id, true, [&](const ICacheManager::CompiledBlobVariant& compiled_blob) { ASSERT_TRUE(std::holds_alternative(compiled_blob)); ++read_count; - // CVS-181859 Check support for multimap memory mapping auto& tensor = std::get(compiled_blob); ASSERT_EQ(tensor.get_byte_size(), blob_data.size()); std::vector read_data(blob_data.size()); @@ -124,7 +130,7 @@ TEST_F(SingleFileStorageTest, BlobAlignment) { while (stream.good() && stream.tellg() < stream_end) { SingleFileStorage::Tag tag; - runtime::TLVTraits::LengthType length; + TLVTraits::LengthType length; stream.read(reinterpret_cast(&tag), sizeof(tag)); ASSERT_TRUE(stream.good()); stream.read(reinterpret_cast(&length), sizeof(length)); @@ -262,7 +268,7 @@ TEST_F(SingleFileStorageTest, ContextWeightSourceWrite) { while (stream.good() && stream.tellg() < stream_end) { SingleFileStorage::Tag tag; - runtime::TLVTraits::LengthType length; + TLVTraits::LengthType length; stream.read(reinterpret_cast(&tag), sizeof(tag)); ASSERT_TRUE(stream.good()); stream.read(reinterpret_cast(&length), sizeof(length)); @@ -311,4 +317,111 @@ TEST_F(SingleFileStorageTest, ContextWeightSourceAppendDelta) { EXPECT_EQ(file_size_after_second_write, file_size_after_first_write) << "Rewriting the same context should not increase file size"; } + +TEST_F(SingleFileStorageTest, WriterMisposition) { + OV_EXPECT_THROW(m_storage->write_cache_entry("42", + [&](std::ostream& s) { + s.seekp(-1, std::ios::cur); + }), + AssertFailure, + ::testing::HasSubstr("Invalid blob size")); + OV_EXPECT_THROW(m_storage->write_cache_entry("41", + [&](std::ostream& s) { + s.seekp(0, std::ios::beg); + }), + AssertFailure, + ::testing::HasSubstr("Invalid blob size")); +} + +TEST_P(SingleFileStorageTest, WrongSizeWritten) { + m_storage.reset(); + std::fstream fs(m_file_path, std::ios::binary | std::ios::in | std::ios::out | std::ios::ate); + const auto& [tag, writer] = GetParam(); + runtime::write_tlv_record(fs, static_cast(tag), writer); + { + /* Append ignoreable data to fill the file allowing SingleFileStorage::initialize() to "read beyond" the + * malformed record. Otherwise stream not good check would trigger expected exception. */ + constexpr size_t sz = 100 * (sizeof(TLVTraits::TagType) + sizeof(TLVTraits::LengthType)); + std::vector data(sz, 0); + fs.write(data.data(), data.size()); + } + fs.close(); + OV_EXPECT_THROW(SingleFileStorage{m_file_path}.initialize(), + AssertFailure, + ::testing::HasSubstr("cache file may be corrupted")); +} +using sfstp = SingleFileStorageTestParam; +INSTANTIATE_TEST_SUITE_P( + Initialize, + SingleFileStorageTest, + ::testing::ValuesIn({sfstp{SingleFileStorage::Tag::Blob, + [](std::ostream& s) { + s.put('a'); + }}, + sfstp{SingleFileStorage::Tag::Blob, + [](std::ostream& s) { + const SingleFileStorage::BlobIdType id = 0x7531; + const SingleFileStorage::PadSizeType padding_size = 0x17; + s.write(reinterpret_cast(&id), sizeof(id)); + s.write(reinterpret_cast(&padding_size), sizeof(padding_size)); + const std::vector too_short_padding(padding_size - 1, 0); + s.write(too_short_padding.data(), too_short_padding.size()); + }}, + sfstp{SingleFileStorage::Tag::BlobMap, + [](std::ostream& s) { + const std::string text = "test"; + s.write(text.data(), text.size()); + }}, + sfstp{SingleFileStorage::Tag::BlobMap, + [](std::ostream& s) { + const SingleFileStorage::BlobIdType id = 07531; + const TLVTraits::TagType tag = + static_cast(SingleFileStorage::Tag::String); + const std::string text = "test"; + const TLVTraits::LengthType length = text.size() - 1; + s.write(reinterpret_cast(&id), sizeof(id)); + s.write(reinterpret_cast(&tag), sizeof(tag)); + s.write(reinterpret_cast(&length), sizeof(length)); + s.write(text.data(), text.size()); + }}, + sfstp{SingleFileStorage::Tag::BlobMap, + [](std::ostream& s) { + const SingleFileStorage::BlobIdType id = 07531; + const TLVTraits::TagType tag = + static_cast(SingleFileStorage::Tag::String); + const std::string text = "test"; + const TLVTraits::LengthType length = text.size() + 1; + s.write(reinterpret_cast(&id), sizeof(id)); + s.write(reinterpret_cast(&tag), sizeof(tag)); + s.write(reinterpret_cast(&length), sizeof(length)); + s.write(text.data(), text.size()); + }}, + sfstp{SingleFileStorage::Tag::ConstantMeta, + [](std::ostream& s) { + s.put('b'); + }}, + sfstp{SingleFileStorage::Tag::ConstantMeta, + [](std::ostream& s) { + const std::vector data(4 * sizeof(uint64_t) + sizeof(uint8_t) - 1, 0); + s.write(data.data(), data.size()); + }}, + sfstp{SingleFileStorage::Tag::ConstantMeta, + [](std::ostream& s) { + const std::vector data(7 * sizeof(uint64_t) + 2 * sizeof(uint8_t) + 1, 0); + s.write(data.data(), data.size()); + }}, + sfstp{SingleFileStorage::Tag::WeightSource, + [](std::ostream& s) { + s.put('c'); + }}, + sfstp{SingleFileStorage::Tag::WeightSource, [](std::ostream& s) { + const SingleFileStorage::DataIdType device_id = 9; + const SingleFileStorage::DataIdType source_id = 10; + const SingleFileStorage::PadSizeType padding_size = 11; + s.write(reinterpret_cast(&device_id), sizeof(device_id)); + s.write(reinterpret_cast(&source_id), sizeof(source_id)); + s.write(reinterpret_cast(&padding_size), sizeof(padding_size)); + const std::vector padding(padding_size - 1, 0); + s.write(padding.data(), padding.size()); + }}})); } // namespace ov::test