diff --git a/litert/runtime/from_tflite/BUILD b/litert/runtime/from_tflite/BUILD new file mode 100644 index 0000000000..ee7b049fae --- /dev/null +++ b/litert/runtime/from_tflite/BUILD @@ -0,0 +1,35 @@ +# Copyright 2026 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@rules_cc//cc:cc_library.bzl", "cc_library") + +package( + # copybara:uncomment default_applicable_licenses = ["//third_party/odml:license"], + default_visibility = ["//visibility:public"], +) + +cc_library( + name = "allocation", + srcs = ["allocation.cc"], + hdrs = ["allocation.h"], + deps = [ + ":error_reporter", + ], +) + +cc_library( + name = "error_reporter", + srcs = ["error_reporter.cc"], + hdrs = ["error_reporter.h"], +) diff --git a/litert/runtime/from_tflite/allocation.cc b/litert/runtime/from_tflite/allocation.cc new file mode 100644 index 0000000000..8a2799d57b --- /dev/null +++ b/litert/runtime/from_tflite/allocation.cc @@ -0,0 +1,163 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "litert/runtime/from_tflite/allocation.h" + +#include +#include +#include + +#include +#include +#include +#include +#include + +#include "litert/runtime/from_tflite/error_reporter.h" + +namespace tflite { + +#ifndef TFLITE_MCU +FileCopyAllocation::FileCopyAllocation(const char* filename, + ErrorReporter* error_reporter) + : Allocation(error_reporter, Allocation::Type::kFileCopy) { + // Obtain the file size using fstat, or report an error if that fails. + std::unique_ptr file(fopen(filename, "rb"), fclose); + if (!file) { + error_reporter_->Report("Could not open '%s'.", filename); + return; + } + struct stat sb; + +// support usage of msvc's posix-like fileno symbol +#ifdef _WIN32 +#define FILENO(_x) _fileno(_x) +#else +#define FILENO(_x) fileno(_x) +#endif + if (fstat(FILENO(file.get()), &sb) != 0) { + error_reporter_->Report("Failed to get file size of '%s'.", filename); + return; + } +#undef FILENO + buffer_size_bytes_ = sb.st_size; + std::unique_ptr buffer(new char[buffer_size_bytes_]); + if (!buffer) { + error_reporter_->Report("Malloc of buffer to hold copy of '%s' failed.", + filename); + return; + } + size_t bytes_read = + fread(buffer.get(), sizeof(char), buffer_size_bytes_, file.get()); + if (bytes_read != buffer_size_bytes_) { + error_reporter_->Report("Read of '%s' failed (too few bytes read).", + filename); + return; + } + // Versions of GCC before 6.2.0 don't support std::move from non-const + // char[] to const char[] unique_ptrs. + copied_buffer_.reset(const_cast(buffer.release())); +} + +FileCopyAllocation::~FileCopyAllocation() = default; + +const void* FileCopyAllocation::base() const { return copied_buffer_.get(); } + +size_t FileCopyAllocation::bytes() const { return buffer_size_bytes_; } + +bool FileCopyAllocation::valid() const { return copied_buffer_ != nullptr; } +#endif + +MemoryAllocation::MemoryAllocation(const void* ptr, size_t num_bytes, + ErrorReporter* error_reporter) + : Allocation(error_reporter, Allocation::Type::kMemory) { +#ifdef __arm__ + if ((reinterpret_cast(ptr) & 0x3) != 0) { + // The flatbuffer schema has alignment requirements of up to 16 bytes to + // guarantee that data can be correctly accesses by various backends. + // Therefore, model pointer should also be 16-bytes aligned to preserve this + // requirement. But this condition only checks 4-bytes alignment which is + // the mininum requirement to prevent SIGBUS fault on 32bit ARM. Some models + // could require 8 or 16 bytes alignment which is not checked yet. + // + // Note that 64-bit ARM may also suffer a performance impact, but no crash - + // that case is not checked. + TF_LITE_REPORT_ERROR(error_reporter, + "The supplied buffer is not 4-bytes aligned"); + buffer_ = nullptr; + buffer_size_bytes_ = 0; + return; + } +#endif // __arm__ + +// `android_local_test` doesn't support zipalign b/356640509 so we need this +// workaround to keep our clients working. +// TODO: b/356413060 - Remove the workaround once b/356640509 is fixed. +#if defined(__x86_64__) && defined(UNDEFINED_BEHAVIOR_SANITIZER) + if ((reinterpret_cast(ptr) & 0x3) != 0) { +#if defined(_WIN32) + // Windows / MSVC + aligned_ptr_ = _aligned_malloc(num_bytes, 4); +#elif defined(__ANDROID__) && __ANDROID_API__ < 28 + // Older Android (API < 28) + if (posix_memalign(&aligned_ptr_, 4, num_bytes) != 0) { + aligned_ptr_ = nullptr; + } + +#elif defined(__APPLE__) + // macOS/iOS: aligned_alloc is technically 10.15+, + // posix_memalign is safer for backwards compatibility. + if (posix_memalign(&aligned_ptr_, 4, num_bytes) != 0) { + aligned_ptr_ = nullptr; + } + +#else + // Standard C11 (Modern Linux, Android API 28+) + aligned_ptr_ = ::aligned_alloc(4, num_bytes); +#endif + + if (aligned_ptr_ == nullptr) { + TF_LITE_REPORT_ERROR(error_reporter, "Failed to allocate aligned buffer"); + buffer_ = nullptr; + buffer_size_bytes_ = 0; + return; + } + memcpy(aligned_ptr_, ptr, num_bytes); + buffer_ = aligned_ptr_; + } else { + buffer_ = ptr; + } +#else // defined(__x86_64__) && defined(UNDEFINED_BEHAVIOR_SANITIZER) + buffer_ = ptr; +#endif // defined(__x86_64__) && defined(UNDEFINED_BEHAVIOR_SANITIZER) + + buffer_size_bytes_ = num_bytes; +} + +MemoryAllocation::~MemoryAllocation() { +#if defined(__x86_64__) && defined(UNDEFINED_BEHAVIOR_SANITIZER) + if (aligned_ptr_) { + free(aligned_ptr_); + } +#endif +} + +const void* MemoryAllocation::base() const { return buffer_; } + +size_t MemoryAllocation::bytes() const { return buffer_size_bytes_; } + +bool MemoryAllocation::valid() const { return buffer_ != nullptr; } + +} // namespace tflite diff --git a/litert/runtime/from_tflite/allocation.h b/litert/runtime/from_tflite/allocation.h new file mode 100644 index 0000000000..55cb527a3d --- /dev/null +++ b/litert/runtime/from_tflite/allocation.h @@ -0,0 +1,173 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +/// \file +/// +/// Memory management for TF Lite. +#ifndef ODML_LITERT_LITERT_RUNTIME_FROM_TFLITE_ALLOCATION_H_ +#define ODML_LITERT_LITERT_RUNTIME_FROM_TFLITE_ALLOCATION_H_ + +#include + +#include +#include +#include + +#include "litert/runtime/from_tflite/error_reporter.h" + +namespace tflite { + +/// A memory allocation handle. This could be a mmap or shared memory. +class Allocation { + public: + using Ptr = std::unique_ptr; + + virtual ~Allocation() = default; + + enum class Type { + kMMap, + kFileCopy, + kMemory, + }; + + /// Base pointer of this allocation + virtual const void* base() const = 0; + /// Size in bytes of the allocation + virtual size_t bytes() const = 0; + /// Whether the allocation is valid + virtual bool valid() const = 0; + /// Return the type of the Allocation. + Type type() const { return type_; } + + protected: + Allocation(ErrorReporter* error_reporter, Type type) + : error_reporter_(error_reporter), type_(type) {} + ErrorReporter* error_reporter_; + + private: + const Type type_; +}; + +/// Note that not all platforms support MMAP-based allocation. +/// Use `IsSupported()` to check. +class MMAPAllocation : public Allocation { + public: + /// Loads and maps the provided file to a memory region. + /// If map_private is true, the mapping is private and writeable. Otherwise, + /// the mapping is shared and read-only. + MMAPAllocation(const char* filename, ErrorReporter* error_reporter, + bool map_private = false); + + /// Loads and maps the provided file to a memory region at the given + /// offset and length (both in bytes). + /// If map_private is true, the mapping is private and writeable. Otherwise, + /// the mapping is shared and read-only. + MMAPAllocation(const char* filename, size_t offset, size_t length, + ErrorReporter* error_reporter, bool map_private = false); + + /// Maps the provided file descriptor to a memory region. + /// If map_private is true, the mapping is private and writeable. Otherwise, + /// the mapping is shared and read-only. + /// Note: The provided file descriptor will be dup'ed for usage; the caller + /// retains ownership of the provided descriptor and should close accordingly. + MMAPAllocation(int fd, ErrorReporter* error_reporter, + bool map_private = false); + + /// Maps the provided file descriptor, with the given offset and length (both + /// in bytes), to a memory region. + /// If map_private is true, the mapping is private and writeable. Otherwise, + /// the mapping is shared and read-only. + /// Note: The provided file descriptor will be dup'ed for usage; the caller + /// retains ownership of the provided descriptor and should close accordingly. + MMAPAllocation(int fd, size_t offset, size_t length, + ErrorReporter* error_reporter, bool map_private = false); + + ~MMAPAllocation() override; + const void* base() const override; + size_t bytes() const override; + bool valid() const override; + + int fd() const { return mmap_fd_; } + + // The start address of the mmapped buffer. + // This will be base() rounded down to the nearest page boundary. + const void* mmapped_buffer() const { return mmapped_buffer_; } + + // The size of the mmapped buffer. + size_t mmapped_buffer_size() const { return bytes() + offset_in_buffer_; } + + // Offset of mmapped_buffer() in the file referenced by the file descriptor. + size_t mmapped_buffer_offset_in_file() const { + return offset_of_buffer_in_file_; + } + + static bool IsSupported(); + + protected: + // Data required for mmap. + int mmap_fd_ = -1; // mmap file descriptor + const void* mmapped_buffer_; + size_t buffer_size_bytes_ = 0; + // Used when the address to mmap is not page-aligned. + size_t offset_in_buffer_ = 0; + size_t offset_of_buffer_in_file_ = 0; + + private: + // Assumes ownership of the provided `owned_fd` instance. + MMAPAllocation(ErrorReporter* error_reporter, int owned_fd, bool map_private); + + // Assumes ownership of the provided `owned_fd` instance, and uses the given + // offset and length (both in bytes) for memory mapping. + MMAPAllocation(ErrorReporter* error_reporter, int owned_fd, size_t offset, + size_t length, bool map_private); +}; + +class FileCopyAllocation : public Allocation { + public: + /// Loads the provided file into a heap memory region. + FileCopyAllocation(const char* filename, ErrorReporter* error_reporter); + ~FileCopyAllocation() override; + const void* base() const override; + size_t bytes() const override; + bool valid() const override; + + private: + std::unique_ptr copied_buffer_; + size_t buffer_size_bytes_ = 0; +}; + +class MemoryAllocation : public Allocation { + public: + /// Provides a (read-only) view of the provided buffer region as an + /// allocation. + /// Note: The caller retains ownership of `ptr`, and must ensure it remains + /// valid for the lifetime of the class instance. + MemoryAllocation(const void* ptr, size_t num_bytes, + ErrorReporter* error_reporter); + ~MemoryAllocation() override; + const void* base() const override; + size_t bytes() const override; + bool valid() const override; + + private: + const void* buffer_; +#if defined(__x86_64__) && defined(UNDEFINED_BEHAVIOR_SANITIZER) + void* aligned_ptr_ = nullptr; +#endif + size_t buffer_size_bytes_ = 0; +}; + +} // namespace tflite + +#endif // ODML_LITERT_LITERT_RUNTIME_FROM_TFLITE_ALLOCATION_H_ diff --git a/litert/runtime/from_tflite/error_reporter.cc b/litert/runtime/from_tflite/error_reporter.cc new file mode 100644 index 0000000000..5b9c02338e --- /dev/null +++ b/litert/runtime/from_tflite/error_reporter.cc @@ -0,0 +1,39 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "litert/runtime/from_tflite/error_reporter.h" + +#include + +namespace tflite { + +int ErrorReporter::Report(const char* format, ...) { + va_list args; + va_start(args, format); + int code = Report(format, args); + va_end(args); + return code; +} + +// TODO(aselle): Make the name of ReportError on context the same, so +// we can use the ensure functions w/o a context and w/ a reporter. +int ErrorReporter::ReportError(void*, const char* format, ...) { + va_list args; + va_start(args, format); + int code = Report(format, args); + va_end(args); + return code; +} + +} // namespace tflite diff --git a/litert/runtime/from_tflite/error_reporter.h b/litert/runtime/from_tflite/error_reporter.h new file mode 100644 index 0000000000..be4d775003 --- /dev/null +++ b/litert/runtime/from_tflite/error_reporter.h @@ -0,0 +1,72 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef ODML_LITERT_LITERT_RUNTIME_FROM_TFLITE_ERROR_REPORTER_H_ +#define ODML_LITERT_LITERT_RUNTIME_FROM_TFLITE_ERROR_REPORTER_H_ + +#include + +namespace tflite { + +/// A functor that reports error to supporting system. Invoked similar to +/// printf. +/// +/// Usage: +/// ErrorReporter foo; +/// foo.Report("test %d", 5); +/// or +/// va_list args; +/// foo.Report("test %d", args); // where args is va_list +/// +/// Subclass ErrorReporter to provide another reporting destination. +/// For example, if you have a GUI program, you might redirect to a buffer +/// that drives a GUI error log box. +class ErrorReporter { + public: + virtual ~ErrorReporter() = default; + /// Converts `args` to character equivalents according to `format` string, + /// constructs the error string and report it. + /// Returns number of characters written or zero on success, and negative + /// number on error. + virtual int Report(const char* format, va_list args) = 0; + + /// Converts arguments to character equivalents according to `format` string, + /// constructs the error string and report it. + /// Returns number of characters written or zero on success, and negative + /// number on error. + int Report(const char* format, ...); + + /// Equivalent to `Report` above. The additional `void*` parameter is unused. + /// This method is for compatibility with macros that takes `TfLiteContext`, + /// like TF_LITE_ENSURE and related macros. + int ReportError(void*, const char* format, ...); +}; + +} // namespace tflite + +// You should not make bare calls to the error reporter, instead use the +// TF_LITE_REPORT_ERROR macro, since this allows message strings to be +// stripped when the binary size has to be optimized. If you are looking to +// reduce binary size, define TF_LITE_STRIP_ERROR_STRINGS when compiling and +// every call will be stubbed out, taking no memory. +#ifndef TF_LITE_STRIP_ERROR_STRINGS +#define TF_LITE_REPORT_ERROR(reporter, ...) \ + do { \ + static_cast<::tflite::ErrorReporter*>(reporter)->Report(__VA_ARGS__); \ + } while (false) +#else // TF_LITE_STRIP_ERROR_STRINGS +#define TF_LITE_REPORT_ERROR(reporter, ...) +#endif // TF_LITE_STRIP_ERROR_STRINGS + +#endif // ODML_LITERT_LITERT_RUNTIME_FROM_TFLITE_ERROR_REPORTER_H_