Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -71,16 +71,16 @@ class SingleFileStorage final : public ICacheManager, public IContextStore {
std::filesystem::path m_file_path;

struct BlobInfo {
std::streampos offset;
std::streamoff size;
uint64_t offset;
uint64_t size;
Comment thread
t-jankowski marked this conversation as resolved.
std::string model_name;
};
std::unordered_map<BlobIdType, BlobInfo> m_blob_index;
std::shared_ptr<wsh::Context> m_shared_context;
bool build_content_index(std::ifstream& stream);

static BlobIdType convert_blob_id(const std::string& blob_id);
void write_blob_entry(std::ofstream& stream, BlobIdType blob_id, StreamWriter& writer);
void write_blob_entry(std::fstream& stream, BlobIdType blob_id, StreamWriter& writer);
Comment thread
praasz marked this conversation as resolved.
bool has_blob_id(BlobIdType blob_id) const;
};
} // namespace ov::runtime
19 changes: 15 additions & 4 deletions src/inference/src/dev/tlv_format.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#include "openvino/runtime/tlv_format.hpp"

#include <limits>
#include <type_traits>
#include <utility>

Expand Down Expand Up @@ -48,6 +49,19 @@ static bool read_record(std::istream& stream, TLVTraits::TagType& tag, TLVTraits
data.clear();
return true;
}
constexpr auto size_limit =
std::min(static_cast<TLVTraits::LengthType>(std::numeric_limits<std::streamsize>::max()),
static_cast<TLVTraits::LengthType>(std::numeric_limits<size_t>::max()));
if (size > size_limit) {
return false;
}
const auto current_pos = stream.tellg();
const auto stream_end = stream.seekg(0, std::ios::end).tellg();
stream.seekg(current_pos);
const auto remaining_offset = stream.good() ? static_cast<TLVTraits::LengthType>(stream_end - current_pos) : 0u;
if (remaining_offset < size) {
return false;
}
Comment thread
t-jankowski marked this conversation as resolved.
data.resize(size);
stream.read(reinterpret_cast<char*>(data.data()), size);
return stream.good();
Comment thread
t-jankowski marked this conversation as resolved.
Expand Down Expand Up @@ -75,7 +89,7 @@ bool scan_tlv_records(std::istream& stream, const TLVValueScanner& scanners) {
return false;
}
stream.read(reinterpret_cast<char*>(&size), sizeof(size));
if (!stream.good()) {
if (!stream.good() || (stream_end - stream.tellg() < static_cast<std::streamoff>(size))) {
return false;
}

Expand All @@ -84,9 +98,6 @@ bool scan_tlv_records(std::istream& stream, const TLVValueScanner& scanners) {
return false;
}
} else {
if (stream_end - stream.tellg() < static_cast<std::streamoff>(size) || !stream.good()) {
return false;
}
stream.seekg(size, std::ios::cur);
if (!stream.good()) {
return false;
Expand Down
120 changes: 60 additions & 60 deletions src/inference/src/single_file_storage.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,12 @@ bool read_tlv_string(std::istream& stream, std::string& str) {
TLVTraits::TagType tag;
TLVTraits::LengthType size;
std::vector<char> buffer;
const auto read = read_tlv_record(stream, tag, size, buffer);
if (read) {
OPENVINO_ASSERT(SingleFileStorage::Tag{tag} == SingleFileStorage::Tag::String);
if (read_tlv_record(stream, tag, size, buffer) && SingleFileStorage::Tag{tag} == SingleFileStorage::Tag::String) {
str = std::string{buffer.begin(), buffer.end()};
return true;
} else {
return false;
}
return read;
}

void write_padding(std::ostream& stream, uint64_t alignment) {
Expand Down Expand Up @@ -89,115 +89,114 @@ SingleFileStorage::SingleFileStorage(const std::filesystem::path& path)
}

bool SingleFileStorage::build_content_index(std::ifstream& stream) {
const auto current_pos = stream.tellg();
const auto end_pos = stream.seekg(0, std::ios::end).tellg();
stream.seekg(current_pos);
const auto blob_reader = [this, end_pos](std::istream& s, TLVTraits::LengthType size) {
const auto blob_reader = [this](std::istream& s, TLVTraits::LengthType size) {
if (size == 0) {
return true;
}
constexpr auto header_size = sizeof(BlobIdType) + sizeof(PadSizeType);
if (size < header_size) {
Comment thread
t-jankowski marked this conversation as resolved.
return false;
}
BlobIdType id;
PadSizeType padding_size;
s.read(reinterpret_cast<char*>(&id), sizeof(id));
s.read(reinterpret_cast<char*>(&padding_size), sizeof(padding_size));
if (s.tellg() + static_cast<std::streamoff>(padding_size) > end_pos) {
if (!s.good() || padding_size > size - header_size) {
return false;
}
s.seekg(padding_size, std::ios::cur);
if (!s.good()) {
const auto blob_data_pos = s.seekg(padding_size, std::ios::cur).tellg();
if (!s.good() || blob_data_pos < 0) {
return false;
}
const auto blob_data_pos = s.tellg();
const auto blob_data_size =
static_cast<std::streamoff>(size - sizeof(id) - sizeof(padding_size) - padding_size);
m_blob_index[id].offset = blob_data_pos;
m_blob_index[id].size = blob_data_size;
if (blob_data_pos + blob_data_size > end_pos) {
const auto blob_data_size = static_cast<std::streamoff>(size - header_size - padding_size);
s.seekg(blob_data_size, std::ios::cur);
if (!s.good()) {
return false;
}
s.seekg(blob_data_size, std::ios::cur);
return s.good();
m_blob_index[id].offset = static_cast<uint64_t>(blob_data_pos);
m_blob_index[id].size = static_cast<uint64_t>(blob_data_size);
return true;
};
const auto blob_map_reader = [this](std::istream& s, TLVTraits::LengthType size) {
if (size == 0) {
return true;
}
if (size < sizeof(BlobIdType) + sizeof(TLVTraits::TagType) + sizeof(TLVTraits::LengthType)) {
return false;
}
BlobIdType id;
s.read(reinterpret_cast<char*>(&id), sizeof(id));
if (!s.good()) {
return false;
}
if (std::string model_name; read_tlv_string(s, model_name)) {
m_blob_index[id].model_name = std::move(model_name);
return true;
} else {
return false;
}
return s.good();
};
const auto constant_meta_reader = [this, end_pos](std::istream& s, TLVTraits::LengthType size) {
const auto constant_meta_reader = [this](std::istream& s, TLVTraits::LengthType size) {
if (size == 0) {
return true;
}
if (s.tellg() + static_cast<std::streamoff>(size) > end_pos) {
uint64_t source_id;
if (size < sizeof(source_id)) {
return false;
}
uint64_t source_id;
s.read(reinterpret_cast<char*>(&source_id), sizeof(source_id));
auto left_size = size - sizeof(source_id);

while (s.good() && left_size > 0) {
auto remaining_size = size - sizeof(source_id);
while (s.good() && remaining_size > 0) {
uint64_t const_id, const_offset, const_size;
uint8_t const_type;
constexpr auto const_meta_size =
sizeof(const_id) + sizeof(const_offset) + sizeof(const_size) + sizeof(const_type);
if (left_size < const_meta_size) {
if (remaining_size < const_meta_size) {
return false;
}
left_size -= const_meta_size;
remaining_size -= const_meta_size;
s.read(reinterpret_cast<char*>(&const_id), sizeof(const_id));
s.read(reinterpret_cast<char*>(&const_offset), sizeof(const_offset));
s.read(reinterpret_cast<char*>(&const_size), sizeof(const_size));
s.read(reinterpret_cast<char*>(&const_type), sizeof(const_type));
if (!s.good()) {
return false;
if (s.good()) {
m_shared_context->m_weight_registry[source_id][const_id] = {static_cast<size_t>(const_offset),
static_cast<size_t>(const_size),
element::Type_t{const_type}};
}
m_shared_context->m_weight_registry[source_id][const_id] = {static_cast<size_t>(const_offset),
static_cast<size_t>(const_size),
element::Type_t{const_type}};
}
return s.good();
};
const auto constant_source_reader = [this, end_pos](std::istream& s, TLVTraits::LengthType size) {
const auto weight_source_reader = [this](std::istream& s, TLVTraits::LengthType size) {
if (size == 0) {
return true;
}
DataIdType device_id, source_id;
PadSizeType padding_size;
if (size < sizeof(device_id) + sizeof(source_id) + sizeof(padding_size) ||
s.tellg() + static_cast<std::streamoff>(size) > end_pos) {
constexpr auto header_size = sizeof(DataIdType) + sizeof(DataIdType) + sizeof(PadSizeType);
if (size < header_size) {
return false;
}
DataIdType device_id, source_id;
PadSizeType padding_size;
s.read(reinterpret_cast<char*>(&device_id), sizeof(device_id));
s.read(reinterpret_cast<char*>(&source_id), sizeof(source_id));
s.read(reinterpret_cast<char*>(&padding_size), sizeof(padding_size));
if (s.tellg() + static_cast<std::streamoff>(padding_size) > end_pos) {
if (!s.good() || padding_size > size - header_size) {
return false;
}
s.seekg(padding_size, std::ios::cur);
if (!s.good()) {
return false;
}
const auto weight_size = size - sizeof(device_id) - sizeof(source_id) - sizeof(padding_size) - padding_size;
const auto weight_size = size - header_size - padding_size;
m_shared_context->m_cache_sources[source_id] = {};
if (s.tellg() + static_cast<std::streamoff>(weight_size) > end_pos) {
return false;
}
s.seekg(weight_size, std::ios::cur);
return s.good();
};
const TLVValueScanner scanners = {
{static_cast<TLVTraits::TagType>(Tag::Blob), blob_reader},
{static_cast<TLVTraits::TagType>(Tag::BlobMap), blob_map_reader},
{static_cast<TLVTraits::TagType>(Tag::ConstantMeta), constant_meta_reader},
{static_cast<TLVTraits::TagType>(Tag::WeightSource), constant_source_reader},
{static_cast<TLVTraits::TagType>(Tag::WeightSource), weight_source_reader},
};
return scan_tlv_records(stream, scanners);
}
Expand All @@ -210,7 +209,7 @@ bool SingleFileStorage::has_blob_id(BlobIdType blob_id) const {
return m_blob_index.find(blob_id) != m_blob_index.end();
}

void SingleFileStorage::write_blob_entry(std::ofstream& stream, BlobIdType blob_id, StreamWriter& writer) {
void SingleFileStorage::write_blob_entry(std::fstream& stream, BlobIdType blob_id, StreamWriter& writer) {
OPENVINO_ASSERT(!has_blob_id(blob_id), "Blob with id ", blob_id, " already exists in cache.");

std::streampos blob_pos;
Expand All @@ -220,8 +219,10 @@ void SingleFileStorage::write_blob_entry(std::ofstream& stream, BlobIdType blob_
s.write(reinterpret_cast<const char*>(&blob_id), sizeof(blob_id));
write_padding(s, blob_alignment);
blob_pos = s.tellp();
OPENVINO_ASSERT(blob_pos >= 0, "Invalid blob data position ", blob_pos, " for blob id ", blob_id);
writer(s);
blob_size = s.tellp() - blob_pos;
OPENVINO_ASSERT(blob_size >= 0, "Invalid blob size ", blob_size, " for blob id ", blob_id);
Comment thread
t-jankowski marked this conversation as resolved.
};
write_tlv_record(stream, static_cast<TLVTraits::TagType>(Tag::Blob), blob_writer);

Expand All @@ -233,12 +234,13 @@ void SingleFileStorage::write_blob_entry(std::ofstream& stream, BlobIdType blob_
};
write_tlv_record(stream, static_cast<TLVTraits::TagType>(Tag::BlobMap), blob_map_writer);

m_blob_index[blob_id] = {blob_pos, blob_size, std::move(model_name)};
m_blob_index[blob_id] = {static_cast<uint64_t>(blob_pos), static_cast<uint64_t>(blob_size), std::move(model_name)};
}
Comment thread
t-jankowski marked this conversation as resolved.

void SingleFileStorage::write_cache_entry(const std::string& blob_id, StreamWriter writer) {
ScopedLocale plocal_C(LC_ALL, "C");
std::ofstream stream(m_file_path, std::ios::binary | std::ios::in | std::ios::ate);
std::fstream stream(m_file_path, std::ios::binary | std::ios::in | std::ios::out | std::ios::ate);
OPENVINO_ASSERT(stream.good(), "Failed to open cache file ", m_file_path, " for writing blob id ", blob_id);
Comment thread
t-jankowski marked this conversation as resolved.
write_blob_entry(stream, convert_blob_id(blob_id), writer);
}

Expand All @@ -250,12 +252,11 @@ void SingleFileStorage::read_cache_entry(const std::string& blob_id, bool enable
if (std::filesystem::exists(m_file_path) && has_blob_id(cid)) {
const auto& [blob_pos, blob_size, model_name] = m_blob_index[cid];
if (enable_mmap) {
// CVS-181859 Extend memory mapping helpers to suport partial file mapping
CompiledBlobVariant compiled_blob{std::in_place_index<0>,
ov::read_tensor_data(m_file_path,
element::u8,
{static_cast<PartialShape::value_type>(blob_size)},
blob_pos)};
read_tensor_data(m_file_path,
element::u8,
{static_cast<PartialShape::value_type>(blob_size)},
blob_pos)};
reader(compiled_blob);
Comment thread
t-jankowski marked this conversation as resolved.
} else {
std::ifstream stream(m_file_path, std::ios::binary);
Expand Down Expand Up @@ -312,19 +313,18 @@ void SingleFileStorage::write_context(const weight_sharing::Context& context) {
}
}
for (const auto& cache_registry : delta_cache_sources) {
const auto weight_source_writer = [&](std::ostream& s) {
const auto& [source_id, weight_buffer] = cache_registry;
if (auto buf = weight_buffer.m_weights.lock()) {
const auto& source_id = cache_registry.first;
const auto& weight_buffer = cache_registry.second;
if (auto weights = weight_buffer.m_weights.lock()) {
write_tlv_record(stream, static_cast<TLVTraits::TagType>(Tag::WeightSource), [&](std::ostream& s) {
const auto device_id = static_cast<uint64_t>(std::strtoul(weight_buffer.m_device.c_str(), nullptr, 10));
s.write(reinterpret_cast<const char*>(&device_id), sizeof(device_id));
s.write(reinterpret_cast<const char*>(&source_id), sizeof(source_id));
write_padding(s, blob_alignment);
s.write(reinterpret_cast<const char*>(buf->get_ptr()), buf->size());
}

s.write(weights->get_ptr<char>(), weights->size());
});
m_shared_context->m_cache_sources[source_id] = weight_buffer;
};
write_tlv_record(stream, static_cast<TLVTraits::TagType>(Tag::WeightSource), weight_source_writer);
}
}

for (const auto& [source_id, buffer] : context.m_runtime_sources) {
Expand Down
Loading
Loading