diff --git a/python/tflite_micro/BUILD b/python/tflite_micro/BUILD index b358fd12adc..812cf7092fd 100644 --- a/python/tflite_micro/BUILD +++ b/python/tflite_micro/BUILD @@ -125,7 +125,7 @@ py_test( ":runtime", requirement("numpy"), requirement("tensorflow"), - "//tensorflow/lite/micro/compression", + "//tensorflow/lite/micro/compression:model_editor", ], ) diff --git a/python/tflite_micro/_runtime.cc b/python/tflite_micro/_runtime.cc index 246545fd016..53825f14f0d 100644 --- a/python/tflite_micro/_runtime.cc +++ b/python/tflite_micro/_runtime.cc @@ -33,10 +33,11 @@ PYBIND11_MODULE(_runtime, m) { .def(py::init([](const py::bytes& data, const std::vector& registerers_by_name, size_t arena_size, int num_resource_variables, - tflite::InterpreterConfig config) { - return std::unique_ptr( - new InterpreterWrapper(data.ptr(), registerers_by_name, arena_size, - num_resource_variables, config)); + tflite::InterpreterConfig config, + size_t alt_decompression_memory_size) { + return std::unique_ptr(new InterpreterWrapper( + data.ptr(), registerers_by_name, arena_size, num_resource_variables, + config, alt_decompression_memory_size)); })) .def("PrintAllocations", &InterpreterWrapper::PrintAllocations) .def("Invoke", &InterpreterWrapper::Invoke) diff --git a/python/tflite_micro/interpreter_wrapper.cc b/python/tflite_micro/interpreter_wrapper.cc index 669589890ad..c74ab84736b 100644 --- a/python/tflite_micro/interpreter_wrapper.cc +++ b/python/tflite_micro/interpreter_wrapper.cc @@ -238,7 +238,14 @@ InterpreterWrapper::~InterpreterWrapper() { InterpreterWrapper::InterpreterWrapper( PyObject* model_data, const std::vector& registerers_by_name, - size_t arena_size, int num_resource_variables, InterpreterConfig config) { + size_t arena_size, int num_resource_variables, InterpreterConfig config, + size_t alt_decompression_memory_size) + : memory_arena_(new uint8_t[arena_size]), + alt_decompression_memory_(alt_decompression_memory_size > 0 + ? new uint8_t[alt_decompression_memory_size] + : nullptr), + alt_decompression_region_{alt_decompression_memory_.get(), + alt_decompression_memory_size} { interpreter_ = nullptr; // `model_data` is used as a raw pointer beyond the scope of this @@ -266,7 +273,6 @@ InterpreterWrapper::InterpreterWrapper( "--//:with_compression=true to enable compression support."); } - memory_arena_ = std::unique_ptr(new uint8_t[arena_size]); for (const std::string& registerer : registerers_by_name) { if (!AddCustomOpRegistererByName(registerer.c_str(), &python_ops_resolver_)) { @@ -296,6 +302,14 @@ InterpreterWrapper::InterpreterWrapper( interpreter_ = new MicroInterpreter(model, python_ops_resolver_, allocator_, resource_variables_); + if (alt_decompression_memory_size > 0) { + TfLiteStatus status = + interpreter_->SetDecompressionMemory(&alt_decompression_region_, 1); + if (status != kTfLiteOk) { + ThrowRuntimeError("TFLM failed to set decompression memory"); + } + } + TfLiteStatus status = interpreter_->AllocateTensors(); if (status != kTfLiteOk) { ThrowRuntimeError("TFLM failed to allocate tensors"); diff --git a/python/tflite_micro/interpreter_wrapper.h b/python/tflite_micro/interpreter_wrapper.h index 9bb31b067fe..d3a156b337a 100644 --- a/python/tflite_micro/interpreter_wrapper.h +++ b/python/tflite_micro/interpreter_wrapper.h @@ -19,6 +19,7 @@ limitations under the License. #include "python/tflite_micro/python_ops_resolver.h" #include "tensorflow/lite/micro/micro_allocator.h" +#include "tensorflow/lite/micro/micro_context.h" #include "tensorflow/lite/micro/micro_interpreter.h" #include "tensorflow/lite/micro/recording_micro_allocator.h" @@ -40,7 +41,8 @@ class InterpreterWrapper { InterpreterWrapper( PyObject* model_data, const std::vector& registerers_by_name, size_t arena_size, int num_resource_variables, - InterpreterConfig config = InterpreterConfig::kAllocationRecording); + InterpreterConfig config = InterpreterConfig::kAllocationRecording, + size_t alt_decompression_memory_size = 0); ~InterpreterWrapper(); void PrintAllocations(); @@ -57,6 +59,8 @@ class InterpreterWrapper { tflite::RecordingMicroAllocator* recording_allocator_ = nullptr; const PyObject* model_; std::unique_ptr memory_arena_; + std::unique_ptr alt_decompression_memory_; + tflite::MicroContext::AlternateMemoryRegion alt_decompression_region_; tflite::PythonOpsResolver python_ops_resolver_; tflite::MicroInterpreter* interpreter_; }; diff --git a/python/tflite_micro/runtime.py b/python/tflite_micro/runtime.py index d895f8c4993..7052972b4a6 100644 --- a/python/tflite_micro/runtime.py +++ b/python/tflite_micro/runtime.py @@ -100,6 +100,7 @@ def __init__( custom_op_registerers, arena_size, intrepreter_config=InterpreterConfig.kAllocationRecording, + alt_decompression_memory_size=0, ): if model_data is None: raise ValueError("Model must not be None") @@ -122,6 +123,7 @@ def __init__( arena_size, num_resource_variables, _ENUM_TRANSLATOR[intrepreter_config], + alt_decompression_memory_size, ) @classmethod @@ -131,6 +133,7 @@ def from_file( custom_op_registerers=[], arena_size=None, intrepreter_config=InterpreterConfig.kAllocationRecording, + alt_decompression_memory_size=0, ): """Instantiates a TFLM interpreter from a model .tflite filepath. @@ -140,6 +143,9 @@ def from_file( custom OP registerer arena_size: Tensor arena size in bytes. If unused, tensor arena size will default to 10 times the model size. + alt_decompression_memory_size: Size in bytes of alternate decompression + memory. If non-zero, DECODE operators will use this memory instead of + the main arena for decompressed tensor outputs. Returns: An Interpreter instance @@ -155,6 +161,7 @@ def from_file( custom_op_registerers, arena_size, intrepreter_config, + alt_decompression_memory_size, ) @classmethod @@ -164,6 +171,7 @@ def from_bytes( custom_op_registerers=[], arena_size=None, intrepreter_config=InterpreterConfig.kAllocationRecording, + alt_decompression_memory_size=0, ): """Instantiates a TFLM interpreter from a model in byte array. @@ -173,6 +181,9 @@ def from_bytes( custom OP registerer arena_size: Tensor arena size in bytes. If unused, tensor arena size will default to 10 times the model size. + alt_decompression_memory_size: Size in bytes of alternate decompression + memory. If non-zero, DECODE operators will use this memory instead of + the main arena for decompressed tensor outputs. Returns: An Interpreter instance @@ -183,6 +194,7 @@ def from_bytes( custom_op_registerers, arena_size, intrepreter_config, + alt_decompression_memory_size, ) def print_allocations(self): diff --git a/python/tflite_micro/test_compression_unsupported.py b/python/tflite_micro/test_compression_unsupported.py index 3692dd0a43a..01c598374ce 100644 --- a/python/tflite_micro/test_compression_unsupported.py +++ b/python/tflite_micro/test_compression_unsupported.py @@ -12,84 +12,84 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Test compression metadata detection when compression is disabled.""" +"""Test legacy compression metadata detection when compression is disabled.""" import os import numpy as np import tensorflow as tf from tflite_micro.python.tflite_micro import runtime -from tflite_micro.tensorflow.lite.micro import compression +from tflite_micro.tensorflow.lite.micro.compression import model_editor -class CompressionDetectionTest(tf.test.TestCase): - """Test compression metadata detection when compression is disabled.""" +def _create_test_model(): + """Create a simple quantized model for testing.""" + model = tf.keras.Sequential([ + tf.keras.layers.Dense(10, input_shape=(5, ), activation='relu'), + tf.keras.layers.Dense(5, activation='softmax') + ]) + model.compile(optimizer='adam', loss='sparse_categorical_crossentropy') - def _create_test_model(self): - """Create a simple quantized model for testing.""" - model = tf.keras.Sequential([ - tf.keras.layers.Dense(10, input_shape=(5, ), activation='relu'), - tf.keras.layers.Dense(5, activation='softmax') - ]) - model.compile(optimizer='adam', loss='sparse_categorical_crossentropy') + converter = tf.lite.TFLiteConverter.from_keras_model(model) + converter.optimizations = [tf.lite.Optimize.DEFAULT] - # Convert to quantized TFLite - converter = tf.lite.TFLiteConverter.from_keras_model(model) - converter.optimizations = [tf.lite.Optimize.DEFAULT] + def representative_dataset(): + for _ in range(10): + yield [np.random.randn(1, 5).astype(np.float32)] - def representative_dataset(): - for _ in range(10): - yield [np.random.randn(1, 5).astype(np.float32)] + converter.representative_dataset = representative_dataset + converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] + converter.inference_input_type = tf.uint8 + converter.inference_output_type = tf.uint8 - converter.representative_dataset = representative_dataset - converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] - converter.inference_input_type = tf.uint8 - converter.inference_output_type = tf.uint8 + tflite_model = converter.convert() + return bytes(tflite_model) if isinstance(tflite_model, + bytearray) else tflite_model - tflite_model = converter.convert() - return bytes(tflite_model) if isinstance(tflite_model, - bytearray) else tflite_model + +def _inject_compression_metadata(model_data): + """Inject raw COMPRESSION_METADATA into a model's flatbuffer metadata. + + This simulates a legacy-compressed model (one that uses the + COMPRESSION_METADATA metadata entry and kernel-level decompression) without + going through compress(), which now produces DECODE-based output. + """ + model = model_editor.read(model_data) + model.metadata["COMPRESSION_METADATA"] = b"\x00" + return bytes(model.build()) + + +class LegacyCompressionDetectionTest(tf.test.TestCase): + """Test that legacy COMPRESSION_METADATA is rejected without the flag.""" def test_regular_model_loads_successfully(self): """Non-compressed models should load without issues.""" - model_data = self._create_test_model() + model_data = _create_test_model() interpreter = runtime.Interpreter.from_bytes(model_data) self.assertIsNotNone(interpreter) - def test_compressed_model_raises_runtime_error(self): - """Compressed models should raise RuntimeError when compression is disabled.""" - # Create and compress a model - model_data = self._create_test_model() + def test_legacy_compressed_model_raises_runtime_error(self): + """Models with COMPRESSION_METADATA should raise RuntimeError.""" + model_data = _create_test_model() + legacy_model = _inject_compression_metadata(model_data) - spec = (compression.SpecBuilder().add_tensor( - subgraph=0, tensor=1).with_lut(index_bitwidth=4).build()) - - compressed_model = compression.compress(model_data, spec) - if isinstance(compressed_model, bytearray): - compressed_model = bytes(compressed_model) - - # Should raise RuntimeError with self.assertRaises(RuntimeError): - runtime.Interpreter.from_bytes(compressed_model) - - def test_can_load_regular_after_compressed_failure(self): - """Verify we can still load regular models after compressed model fails.""" - model_data = self._create_test_model() + runtime.Interpreter.from_bytes(legacy_model) - # First try compressed model (should fail) - spec = (compression.SpecBuilder().add_tensor( - subgraph=0, tensor=1).with_lut(index_bitwidth=4).build()) - compressed_model = compression.compress(model_data, spec) + def test_can_load_regular_after_legacy_failure(self): + """Verify regular models still load after a legacy-compressed failure.""" + model_data = _create_test_model() + legacy_model = _inject_compression_metadata(model_data) with self.assertRaises(RuntimeError): - runtime.Interpreter.from_bytes(bytes(compressed_model)) + runtime.Interpreter.from_bytes(legacy_model) - # Then load regular model (should succeed) interpreter = runtime.Interpreter.from_bytes(model_data) self.assertIsNotNone(interpreter) if __name__ == '__main__': - # Set TF environment variables to suppress warnings + # Suppress TF C++ info/debug logs (0=DEBUG, 1=INFO, 2=WARNING, 3=ERROR) os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' + # Disable oneDNN to avoid non-deterministic floating point results os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0' tf.test.main() diff --git a/tensorflow/lite/micro/compression/BUILD b/tensorflow/lite/micro/compression/BUILD index 36725fac63c..a9fe9fa36de 100644 --- a/tensorflow/lite/micro/compression/BUILD +++ b/tensorflow/lite/micro/compression/BUILD @@ -123,14 +123,15 @@ py_library( "compress.py", ], deps = [ - ":metadata_py", + ":compressor", + ":decode_insert", + ":huffman", + ":lut", ":model_editor", + ":pruning", ":spec", "//tensorflow/lite/micro/tools:tflite_flatbuffer_align", requirement("absl_py"), - requirement("flatbuffers"), - requirement("bitarray"), - requirement("numpy"), ], ) @@ -159,33 +160,54 @@ py_test( target_compatible_with = INCOMPATIBLE_WITH_WINDOWS, deps = [ ":compress", - ":metadata_py", + ":compressor", + ":decode_insert", ":model_editor", ":spec", - ":test_models", "//tensorflow/lite/python:schema_py", - requirement("bitarray"), requirement("numpy"), ], ) -tflm_py_library( - name = "model_facade", - srcs = ["model_facade.py"], +tflm_py_test( + name = "compression_integration_test", + size = "small", + srcs = ["compression_integration_test.py"], + tags = [ + "noasan", + "nomsan", + "noubsan", + ], + target_compatible_with = INCOMPATIBLE_WITH_WINDOWS, deps = [ + ":compress_lib", + ":decode_insert", + ":model_editor", + ":spec", + "//python/tflite_micro:runtime", "//tensorflow/lite/python:schema_py", - requirement("flatbuffers"), + requirement("numpy"), ], ) -py_test( - name = "model_facade_test", +tflm_py_test( + name = "proprietary_integration_test", size = "small", - srcs = ["model_facade_test.py"], + srcs = ["proprietary_integration_test.py"], + tags = [ + "manual", + "noasan", + "nomsan", + "noubsan", + ], target_compatible_with = INCOMPATIBLE_WITH_WINDOWS, deps = [ - ":model_facade", - ":test_models", + ":compress_lib", + ":model_editor", + ":spec", + "//python/tflite_micro:runtime", + "//tensorflow/lite/python:schema_py", + requirement("numpy"), ], ) @@ -227,62 +249,150 @@ py_test( ) tflm_py_library( - name = "test_models", - srcs = ["test_models.py"], + name = "tensor_type", + srcs = ["tensor_type.py"], deps = [ "//tensorflow/lite/python:schema_py", - requirement("flatbuffers"), requirement("numpy"), ], ) -py_test( - name = "test_models_test", +tflm_py_test( + name = "tensor_type_test", size = "small", - srcs = ["test_models_test.py"], - target_compatible_with = INCOMPATIBLE_WITH_WINDOWS, + srcs = ["tensor_type_test.py"], deps = [ - ":test_models", + ":tensor_type", "//tensorflow/lite/python:schema_py", + requirement("numpy"), ], ) tflm_py_library( - name = "tensor_type", - srcs = ["tensor_type.py"], + name = "model_editor", + srcs = ["model_editor.py"], deps = [ + ":tensor_type", "//tensorflow/lite/python:schema_py", + requirement("flatbuffers"), requirement("numpy"), ], ) tflm_py_test( - name = "tensor_type_test", + name = "model_editor_test", size = "small", - srcs = ["tensor_type_test.py"], + srcs = ["model_editor_test.py"], deps = [ - ":tensor_type", + ":model_editor", "//tensorflow/lite/python:schema_py", requirement("numpy"), ], ) tflm_py_library( - name = "model_editor", - srcs = ["model_editor.py"], + name = "decode", + srcs = ["decode.py"], +) + +tflm_py_test( + name = "decode_test", + size = "small", + srcs = ["decode_test.py"], deps = [ - ":tensor_type", + ":decode", + ], +) + +tflm_py_library( + name = "compressor", + srcs = ["compressor.py"], + deps = [ + ":decode", + ":model_editor", + ":spec", + ], +) + +tflm_py_library( + name = "lut", + srcs = ["lut.py"], + deps = [ + ":compressor", + ":decode", + ":model_editor", + ":spec", + requirement("bitarray"), + requirement("numpy"), + ], +) + +tflm_py_test( + name = "lut_test", + size = "small", + srcs = ["lut_test.py"], + tags = [ + "noasan", + "nomsan", + "noubsan", + ], + deps = [ + ":compressor", + ":decode", + ":lut", + ":model_editor", + ":spec", "//tensorflow/lite/python:schema_py", - requirement("flatbuffers"), requirement("numpy"), ], ) +tflm_py_library( + name = "huffman", + srcs = ["huffman.py"], + deps = [ + ":compressor", + ":decode", + ":model_editor", + ":spec", + ], +) + +tflm_py_library( + name = "pruning", + srcs = ["pruning.py"], + deps = [ + ":compressor", + ":decode", + ":model_editor", + ":spec", + ], +) + +tflm_py_library( + name = "decode_insert", + srcs = ["decode_insert.py"], + deps = [ + ":compressor", + ":model_editor", + "//tensorflow/lite/python:schema_py", + ], +) + tflm_py_test( - name = "model_editor_test", + name = "decode_insert_test", size = "small", - srcs = ["model_editor_test.py"], + srcs = ["decode_insert_test.py"], + tags = [ + "noasan", + "nomsan", + "noubsan", + ], deps = [ + ":compressor", + ":decode", + ":decode_insert", + ":lut", ":model_editor", "//tensorflow/lite/python:schema_py", requirement("numpy"), diff --git a/tensorflow/lite/micro/compression/compress.py b/tensorflow/lite/micro/compression/compress.py index b6d5aef4435..96b55d94fd7 100644 --- a/tensorflow/lite/micro/compression/compress.py +++ b/tensorflow/lite/micro/compression/compress.py @@ -16,22 +16,22 @@ See USAGE. """ -import bitarray -import bitarray.util -from dataclasses import dataclass, field import os import sys import tempfile -from typing import ByteString, Iterable, Optional +import warnings +from typing import ByteString, Iterable, Type import absl.app import absl.flags -import flatbuffers -import numpy as np +from tflite_micro.tensorflow.lite.micro.compression import compressor +from tflite_micro.tensorflow.lite.micro.compression import decode_insert +from tflite_micro.tensorflow.lite.micro.compression import huffman +from tflite_micro.tensorflow.lite.micro.compression import lut from tflite_micro.tensorflow.lite.micro.compression import model_editor +from tflite_micro.tensorflow.lite.micro.compression import pruning from tflite_micro.tensorflow.lite.micro.compression import spec -from tflite_micro.tensorflow.lite.micro.compression import metadata_py_generated as schema from tflite_micro.tensorflow.lite.micro.tools import tflite_flatbuffer_align_wrapper USAGE = f"""\ @@ -49,221 +49,48 @@ {spec.EXAMPLE_YAML_SPEC} --- -The only compression method currently implemented is "lut", i.e., -Look-Up-Table. This method requires the tensor in the input model to have a -small number of unique values, fewer than or equal to 2**index_bitwidth. LUT -compression collects these values into a lookup table, and rewrites the tensor -as bitwidth-wide integer indices into that lookup table. Presumably, the input -model has been trained or preprocessed in a way that the tensor values -are binned into a meaningful, limited set. -""" - -# A compressed model augments the usual .tflite flatbuffer with a flatbuffer of -# its own containing compression metadata, stored at the buffer index stored at -# the following key in the .tflite flatbuffer's metadata map. -TFLITE_METADATA_KEY = "COMPRESSION_METADATA" - - -class CompressionError(Exception): - """Raised when compression fails for the reason documented in the message.""" - - def __init__(self, message, wrapped_exception=None): - super().__init__(f"{message}: {str(wrapped_exception)}") - self.original_exception = wrapped_exception - - -class _MetadataBuilder: - """Builder for the compression metadata flatbuffer.""" - - def __init__(self): - self._metadata = schema.MetadataT() - self._metadata.subgraphs = [] - - def compile(self) -> bytearray: - """Packs the metadata into a binary array and returns it. - """ - builder = flatbuffers.Builder(1 * 2**10) - root = self._metadata.Pack(builder) - builder.Finish(root) - return builder.Output() - - def subgraph(self, index: int): - """Return subgraph at index, adding subgraphs if necessary. - """ - while len(self._metadata.subgraphs) <= index: - self._add_subgraph() - return self._metadata.subgraphs[index] - - def add_lut_tensor(self, subgraph_id: int): - """Add LUT tensor to the given subgraph and return it. - """ - tensor = schema.LutTensorT() - self.subgraph(subgraph_id).lutTensors.append(tensor) - return tensor - - def _add_subgraph(self): - subgraph = schema.SubgraphT() - subgraph.lutTensors = [] - self._metadata.subgraphs.append(subgraph) - return subgraph - - -@dataclass -class _LutCompressedArray: - compression_axis: Optional[int] = None - lookup_tables: list[np.ndarray] = field(default_factory=list) - indices: np.ndarray = field(default_factory=lambda: np.array([])) - - @property - def index_bitwidth(self) -> int: - """Returns the number of bits required to encode the indices.""" - if self.indices is None: - raise ValueError - - max_index = int(np.max(self.indices)) - return max_index.bit_length() or 1 - - -def _lut_compress_array(tensor: np.ndarray, - axis: Optional[int]) -> _LutCompressedArray: - """Compresses the given tensor using lookup tables. - - Args: - tensor (np.ndarray): The tensor to be compressed. - - axis (Optional[int]): The axis along which to compress the tensor. If an - axis is given, a lookup table is created for each slice along the - axis. If axis is None, a single lookup table is used for the entire - tensor. - - Compressing a tensor with a lookup table per slice along a - particular axis is analogous to quantizing a tensor with different - quantization parameters per slice along a particular axis (dimension). - - Returns: - _LutCompressedArray: An object containing the compressed tensor data, - including the lookup tables and indices. - """ - compressed = _LutCompressedArray() - compressed.compression_axis = axis - - if axis is None: - # Compute unique values and indices for the entire tensor - values, indices = np.unique(tensor, return_inverse=True) - compressed.lookup_tables.append(values) - compressed.indices = indices.reshape(tensor.shape) - else: - # Iterate over slices along the compression axis - slice_indices = [] - for slice in np.moveaxis(tensor, axis, 0): - values, indices = np.unique(slice, return_inverse=True) - compressed.lookup_tables.append(values) - indices = indices.reshape(slice.shape) - slice_indices.append(indices) +Supported compression methods: - # Reconstruct a tensor of indices from the slices - stacked = np.stack(slice_indices, axis=0) - compressed.indices = np.moveaxis(stacked, 0, axis) + lut: Look-Up-Table compression. Requires the tensor to have a small number of + unique values, fewer than or equal to 2**index_bitwidth. LUT compression + collects these values into a lookup table, and rewrites the tensor as + bitwidth-wide integer indices into that lookup table. - return compressed + huffman: Huffman compression using Xtensa-format decode tables. (Not yet + implemented.) + pruning: Pruning (sparsity) compression for sparse tensors. (Not yet + implemented.) -def _check_lut_compression(compression) -> spec.LookUpTableCompression: - if len(compression) != 1: - raise CompressionError("Each tensor must have exactly one compression") - if not isinstance(compression[0], spec.LookUpTableCompression): - raise CompressionError('Only "lut" compression may be specified') - - return compression[0] - - -def _identify_compression_axis(tensor: model_editor.Tensor) -> Optional[int]: - """Determines the axis along which to compress. - - The axis along which to compress is inferred from the tensor's quantization - parameters. - - Returns: - The axis along which to compress, or None to indicate one value table for - the entire tensor. - - Raises: - CompressionError: If the axis cannot be determined. - """ - q = tensor.quantization - if q is not None: - # model_editor wraps quantization, access scales/axis from wrapper - scales = q.scales if isinstance(q.scales, list) else [q.scales] - quantization_channels = len(scales) - - if quantization_channels == 1: - # Use one value table for the entire tensor - return None - - if q.axis is not None and q.axis < len(tensor.shape): - if quantization_channels == tensor.shape[q.axis]: - return q.axis - - raise CompressionError( - f"Invalid or no quanitzation parameters from which to " - f"infer the axis along which tensor should be compressed.") - - -def _check_bitwidth(compressed: int, specified: int, spec: spec.Tensor): - """Applies business logic regarding specified bitwidth. - - It is an error if the bitwidth required to compress a tensor exceeds the - specified bitwith, and a warning if the tensor can be compressed in less than - the specified bitwidth. The latter is allowed, and is not an error, to permit - testing with larger bitwidths without re-binning a model. - """ - if compressed > specified: - raise CompressionError( - f"index_bitwidth too small: {compressed} bits needed to " - f"enumerate unique values in tensor specified in {spec}") - elif compressed < specified: - print( - f"warning: index_bitwidth too large: only {compressed} " - f"bits needed to enumerate unique values in tensor specified in {spec}", - file=sys.stderr) - - -def _pack_indices(indices: np.ndarray, bitwidth: int) -> bytes: - """Packs indices into a bytearray using bitwidth-sized fields. - """ - endianness = "big" - bits = bitarray.bitarray(endian=endianness) - for i in indices.ravel(): - bits.extend( - bitarray.util.int2ba(int(i), length=bitwidth, endian=endianness)) - return bits.tobytes() - +Compressed models use DECODE operators to decompress tensors at runtime. +""" -def _pack_lookup_tables(tables: list[np.ndarray], table_len: int) -> bytearray: - """Packs the value tables of a LutCompressedArray. +# Plugin dispatch table: maps CompressionMethod subclasses to compressor instances +_COMPRESSORS: dict[Type[spec.CompressionMethod], compressor.Compressor] = { + spec.LookUpTableCompression: lut.LutCompressor(), + spec.HuffmanCompression: huffman.HuffmanCompressor(), + spec.PruningCompression: pruning.PruningCompressor(), +} - Pack the value tables of a LutCompressedArray into a bytes object in the - format writable to a value_table buffer in the .tflite flatbuffer. The - tables are concatenated. - """ - buffer = bytearray() - for t in tables: - padding_needed = table_len - len(t) - padded = np.pad(t, (0, padding_needed), mode='constant', constant_values=0) - buffer.extend(padded.tobytes()) - return buffer +def _get_compressor(method: spec.CompressionMethod) -> compressor.Compressor: + """Get the compressor plugin for a given compression method.""" + compressor_instance = _COMPRESSORS.get(type(method)) + if compressor_instance is None: + raise compressor.CompressionError( + f"No compressor registered for {type(method).__name__}") + return compressor_instance def _apply_flatbuffer_alignment(model_bytes: bytearray) -> bytearray: """Applies proper FlatBuffer alignment to a model. - + The Python flatbuffers library doesn't respect `force_align` schema attributes, so we use the C++ wrapper which properly handles alignment requirements. - + Args: model_bytes: The model flatbuffer to align - + Returns: The properly aligned model flatbuffer """ @@ -295,45 +122,63 @@ def _apply_flatbuffer_alignment(model_bytes: bytearray) -> bytearray: def compress(model_in: ByteString, specs: Iterable[spec.Tensor]) -> bytearray: """Compresses a model .tflite flatbuffer. + Compresses tensors according to the given specs and inserts DECODE operators + to decompress them at runtime. + Args: model_in: the original, uncompressed .tflite flatbuffer specs: an iterable of compression specs, see module spec.py Returns: - A compressed flatbuffer. + A compressed flatbuffer with DECODE operators inserted. """ + specs = list(specs) + if not specs: + raise compressor.CompressionError( + "Compression spec is empty; no tensors to compress") + model = model_editor.read(model_in) - metadata = _MetadataBuilder() + compression_results: dict[tuple[int, int], compressor.CompressionResult] = {} - for spec in specs: + for tensor_spec in specs: try: - tensor = model.subgraphs[spec.subgraph].tensors[spec.tensor] - lut_compression = _check_lut_compression(spec.compression) - spec_bitwidth = lut_compression.index_bitwidth - axis = _identify_compression_axis(tensor) - compressed = _lut_compress_array(tensor.array, axis) - _check_bitwidth(compressed.index_bitwidth, spec_bitwidth, spec) - - # overwrite tensor data with indices - tensor.buffer.data = _pack_indices(compressed.indices, spec_bitwidth) - - # write value buffer - value_buffer_data = _pack_lookup_tables(compressed.lookup_tables, - 2**spec_bitwidth) - value_buffer = model_editor.Buffer(data=value_buffer_data) - model.buffers.append(value_buffer) # Auto-sets value_buffer.index - - # add compression metadata for tensor - lut_tensor = metadata.add_lut_tensor(subgraph_id=spec.subgraph) - lut_tensor.tensor = spec.tensor - lut_tensor.valueBuffer = value_buffer.index - lut_tensor.indexBitwidth = spec_bitwidth - + tensor = model.subgraphs[tensor_spec.subgraph].tensors[ + tensor_spec.tensor] + + # Currently only one compression method per tensor + if len(tensor_spec.compression) != 1: + raise compressor.CompressionError( + "Each tensor must have exactly one compression method") + + method = tensor_spec.compression[0] + plugin = _get_compressor(method) + original_size = len(tensor.buffer.data) if tensor.buffer.data else 0 + result = plugin.compress(tensor, method) + + compressed_size = len(result.encoded_data) + len(result.ancillary_data) + if compressed_size > original_size: + warnings.warn( + f"Compression of tensor {tensor.name!r} (subgraph " + f"{tensor_spec.subgraph}, tensor {tensor_spec.tensor}) resulted " + f"in expansion: {original_size} bytes -> {compressed_size} bytes " + f"(encoded: {len(result.encoded_data)}, " + f"ancillary: {len(result.ancillary_data)})", + stacklevel=2) + + # Replace tensor data with encoded data + tensor.buffer.data = result.encoded_data + + # Store result for DECODE insertion + compression_results[(tensor_spec.subgraph, tensor_spec.tensor)] = result + + except compressor.CompressionError: + raise except Exception as e: - raise CompressionError(f"error compressing {spec}") from e + raise compressor.CompressionError( + f"error compressing {tensor_spec}") from e - # add compression metadata to model - model.metadata[TFLITE_METADATA_KEY] = metadata.compile() + # Insert DECODE operators into the graph + decode_insert.insert_decode_operators(model, compression_results) # Build the model and apply proper alignment unaligned_model = model.build() diff --git a/tensorflow/lite/micro/compression/compress_test.py b/tensorflow/lite/micro/compression/compress_test.py index ee10a75f36d..6ee80f200d5 100644 --- a/tensorflow/lite/micro/compression/compress_test.py +++ b/tensorflow/lite/micro/compression/compress_test.py @@ -11,312 +11,109 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""Integration tests for the compression system.""" + +import warnings -import bitarray -import bitarray.util import numpy as np import unittest from tflite_micro.tensorflow.lite.micro.compression import compress -from tflite_micro.tensorflow.lite.micro.compression import metadata_py_generated as schema +from tflite_micro.tensorflow.lite.micro.compression import compressor +from tflite_micro.tensorflow.lite.micro.compression import decode_insert from tflite_micro.tensorflow.lite.micro.compression import model_editor from tflite_micro.tensorflow.lite.micro.compression import spec -from tflite_micro.tensorflow.lite.micro.compression import test_models from tflite_micro.tensorflow.lite.python import schema_py_generated as tflite -class TestPackIndices(unittest.TestCase): - - def test_basic_case(self): - indices = np.array([1, 2, 3]) - bitwidth = 4 - result = compress._pack_indices(indices, bitwidth) - expected_bytes = bytes([0b0001_0010, 0b0011_0000]) - self.assertEqual(result, expected_bytes) - - def test_single_element(self): - indices = np.array([10]) - bitwidth = 8 - result = compress._pack_indices(indices, bitwidth) - expected_bytes = bytes([0b0000_1010]) - self.assertEqual(result, expected_bytes) - - def test_different_bitwidth(self): - indices = np.array([1, 2, 3]) - bitwidth = 8 - result = compress._pack_indices(indices, bitwidth) - expected_bytes = bytes([0b0000_0001, 0b0000_0010, 0b0000_0011]) - self.assertEqual(result, expected_bytes) - - def test_large_numbers(self): - indices = np.array([255, 128, 64]) - bitwidth = 8 - result = compress._pack_indices(indices, bitwidth) - expected_bytes = bytes([0b1111_1111, 0b1000_0000, 0b0100_0000]) - self.assertEqual(result, expected_bytes) - - def test_multidimensional_array(self): - indices = np.array([[1, 2], [3, 4]]) - bitwidth = 4 - result = compress._pack_indices(indices, bitwidth) - expected_bytes = bytes([0b0001_0010, 0b0011_0100]) - self.assertEqual(result, expected_bytes) - - def test_zero_bitwidth(self): - indices = np.array([0, 1, 2]) - bitwidth = 0 - with self.assertRaises(ValueError): - compress._pack_indices(indices, bitwidth) - - def test_empty_array(self): - indices = np.array([]) - bitwidth = 4 - result = compress._pack_indices(indices, bitwidth) - expected_bytes = b"" - self.assertEqual(result, expected_bytes) - - def test_bitwidth_1(self): - indices = np.array([1, 0, 1, 1, 0, 1]) - bitwidth = 1 - result = compress._pack_indices(indices, bitwidth) - expected_bytes = bytes([0b101101_00]) - self.assertEqual(result, expected_bytes) - - def test_bitwidth_2(self): - indices = np.array([1, 2, 3, 0]) - bitwidth = 2 - result = compress._pack_indices(indices, bitwidth) - expected_bytes = bytes([0b01_10_11_00]) - self.assertEqual(result, expected_bytes) - - def test_bitwidth_3(self): - indices = np.array([1, 3, 5, 7]) - bitwidth = 3 - result = compress._pack_indices(indices, bitwidth) - expected_bytes = bytes([0b001_011_10, 0b1_111_0000]) - self.assertEqual(result, expected_bytes) - - def test_bitwidth_5(self): - indices = np.array([1, 2, 16, 31]) - bitwidth = 5 - result = compress._pack_indices(indices, bitwidth) - expected_bytes = bytes([0b00001_000, 0b10_10000_1, 0b1111_0000]) - self.assertEqual(result, expected_bytes) - - def test_bitwidth_7(self): - indices = np.array([1, 64, 127, 32]) - bitwidth = 7 - result = compress._pack_indices(indices, bitwidth) - expected_bytes = bytes( - [0b0000001_1, 0b000000_11, 0b11111_010, 0b0000_0000]) - self.assertEqual(result, expected_bytes) - - -class TestPackLookupTables(unittest.TestCase): - - def test_int16_positive(self): - tables = [np.array([0x1234, 0x5678], dtype=' tuple[int, bitarray.bitarray, np.ndarray]: - """Helper: extracts the compressed tensor parts for a given spec. - - Returns: - bitwidth - indices - values - """ - subgraph_obj = self.compressed.subgraphs[subgraph] - tensor_obj = subgraph_obj.tensors[tensor] - lut_tensors = self.metadata.subgraphs[subgraph_obj.index].lutTensors - lut_tensor = next(t for t in lut_tensors if t.tensor == tensor_obj.index) - bitwidth = lut_tensor.indexBitwidth - - indices = bitarray.bitarray(buffer=tensor_obj.buffer.data, endian="big") - n_indices = np.prod(tensor_obj.shape) - indices = indices[:n_indices * bitwidth] # trim possible padding - - value_buffer = self.compressed.buffers[lut_tensor.valueBuffer] - values = np.frombuffer(value_buffer.data, dtype=tensor_obj.numpy_dtype) - - return bitwidth, indices, values - - def _make_indices(self, s: str) -> bitarray.bitarray: - """Helper: makes indices from "01" strings for use as expected values.""" - return bitarray.bitarray(s, endian="big") - - def test_compressed_uint8(self): - bitwidth, indices, values = self._get_compressed(subgraph=0, tensor=0) - self.assertEqual(bitwidth, 4) - - # yapf: disable - expected_indices = self._make_indices(""" - 0000 0001 0010 0011 - 0100 0101 0110 0111 - 1000 1001 1010 1011 - 1100 1101 1110 1111 - """) - # yapf: enable - self.assertEqual(indices, expected_indices) - - expected_values = np.array(range(16), dtype=" [FC1 with weights1] -> output1 + input2 -> [FC2 with weights2] -> intermediate -> [FC3 with weights1] -> output2 + + weights1 is shared between FC1 and FC3. weights2 is used only by FC2, which + runs between the two consumers of weights1. + """ + # 4 unique values per tensor for 2-bit LUT compression. Small values avoid + # saturation in chained layers. Different row sums produce varied outputs. + weights1_data = np.array([ + [-1, 0, 0, 1], + [-1, 0, 1, 1], + [-1, 1, 1, 1], + [0, 1, 1, 1], + ], + dtype=np.int8) + weights1 = model_editor.Tensor( + shape=(4, 4), + dtype=tflite.TensorType.INT8, + data=weights1_data, + name="weights1", + quantization=model_editor.Quantization(scales=1.0, zero_points=0), + ) + + weights2_data = np.array([ + [1, 1, 1, 1], + [1, 1, 2, 2], + [1, 2, 2, 3], + [2, 2, 3, 3], + ], + dtype=np.int8) + weights2 = model_editor.Tensor( + shape=(4, 4), + dtype=tflite.TensorType.INT8, + data=weights2_data, + name="weights2", + quantization=model_editor.Quantization(scales=1.0, zero_points=0), + ) + + # All tensors need matching quantization for FULLY_CONNECTED + quant = model_editor.Quantization(scales=1.0, zero_points=0) + + input1 = model_editor.Tensor( + shape=(1, 4), + dtype=tflite.TensorType.INT8, + name="input1", + quantization=quant, + ) + input2 = model_editor.Tensor( + shape=(1, 4), + dtype=tflite.TensorType.INT8, + name="input2", + quantization=quant, + ) + output1 = model_editor.Tensor( + shape=(1, 4), + dtype=tflite.TensorType.INT8, + name="output1", + quantization=quant, + ) + intermediate = model_editor.Tensor( + shape=(1, 4), + dtype=tflite.TensorType.INT8, + name="intermediate", + quantization=quant, + ) + output2 = model_editor.Tensor( + shape=(1, 4), + dtype=tflite.TensorType.INT8, + name="output2", + quantization=quant, + ) + + model = model_editor.Model(subgraphs=[ + model_editor.Subgraph( + tensors=[weights1, weights2], + inputs=[input1, input2], + outputs=[output1, output2], + operators=[ + # FC1: uses weights1 + model_editor.Operator( + opcode=tflite.BuiltinOperator.FULLY_CONNECTED, + inputs=[input1, weights1], + outputs=[output1], + ), + # FC2: uses weights2 (runs between FC1 and FC3) + model_editor.Operator( + opcode=tflite.BuiltinOperator.FULLY_CONNECTED, + inputs=[input2, weights2], + outputs=[intermediate], + ), + # FC3: uses weights1 (second consumer, after DECODE(weights2)) + model_editor.Operator( + opcode=tflite.BuiltinOperator.FULLY_CONNECTED, + inputs=[intermediate, weights1], + outputs=[output2], + ), + ], + ) + ]) + return model.build() + + +class AltDecompressionMemoryTest(unittest.TestCase): + """Tests for alternate decompression memory with shared compressed tensors. + + These tests verify correct behavior when compressed tensors are shared + between multiple operators and alternate decompression memory is enabled. + """ + + def test_shared_compressed_tensor_with_alt_memory(self): + """Verify correct results when a shared compressed tensor is used with alt + decompression memory. + + This test uses a graph where a compressed tensor (weights1) is consumed by + two operators (FC1 and FC3), with an intervening DECODE of a different + compressed tensor (weights2) between them. + + The interpreter's alternate decompression memory has a limitation: each + DECODE's Prepare resets the allocation offset to zero. This means all + DECODE outputs are allocated at the same address, so they overwrite each + other. A DECODE output can only be used until the next DECODE runs. + + To work around this limitation, the DECODE insertion code inserts a + separate DECODE immediately before each consumer of a compressed tensor, + rather than sharing one DECODE output among all consumers. + """ + flatbuffer = _build_shared_weights_model() + + specs = [ + spec.Tensor( + subgraph=0, + tensor=0, # weights1 + compression=[spec.LookUpTableCompression(index_bitwidth=2)], + ), + spec.Tensor( + subgraph=0, + tensor=1, # weights2 + compression=[spec.LookUpTableCompression(index_bitwidth=2)], + ), + ] + + compressed_fb = compress.compress(flatbuffer, specs) + + # Run without alt decompression memory (baseline) + interp_no_alt = runtime.Interpreter.from_bytes(bytes(compressed_fb)) + + # Run with alt decompression memory + interp_with_alt = runtime.Interpreter.from_bytes( + bytes(compressed_fb), + alt_decompression_memory_size=256, + ) + + test_input1 = np.array([[1, 1, 1, 1]], dtype=np.int8) + test_input2 = np.array([[1, 1, 1, 1]], dtype=np.int8) + + interp_no_alt.set_input(test_input1, 0) + interp_no_alt.set_input(test_input2, 1) + interp_no_alt.invoke() + expected1 = interp_no_alt.get_output(0) + expected2 = interp_no_alt.get_output(1) + + interp_with_alt.set_input(test_input1, 0) + interp_with_alt.set_input(test_input2, 1) + interp_with_alt.invoke() + actual1 = interp_with_alt.get_output(0) + actual2 = interp_with_alt.get_output(1) + + np.testing.assert_array_equal( + expected1, actual1, "Output 1 mismatch with alt decompression memory") + np.testing.assert_array_equal( + expected2, actual2, "Output 2 mismatch with alt decompression memory") + + +class HuffmanCompressionTest(unittest.TestCase): + """Integration tests for Huffman compression.""" + + @unittest.skip("Huffman compression not yet implemented") + def test_huffman_compressed_model_matches_uncompressed(self): + """Huffman-compressed model produces same outputs as uncompressed.""" + pass + + @unittest.skip("Huffman compression not yet implemented") + def test_huffman_decode_operators_present(self): + """DECODE operators are inserted for Huffman-compressed tensors.""" + pass + + @unittest.skip("Huffman compression not yet implemented") + def test_huffman_compressed_model_is_smaller(self): + """Huffman-compressed model is smaller than original.""" + pass + + +class PruningCompressionTest(unittest.TestCase): + """Integration tests for pruning compression.""" + + @unittest.skip("Pruning compression not yet implemented") + def test_pruning_compressed_model_matches_uncompressed(self): + """Pruning-compressed model produces same outputs as uncompressed.""" + pass + + @unittest.skip("Pruning compression not yet implemented") + def test_pruning_decode_operators_present(self): + """DECODE operators are inserted for pruning-compressed tensors.""" + pass + + @unittest.skip("Pruning compression not yet implemented") + def test_pruning_compressed_model_is_smaller(self): + """Pruning-compressed model is smaller than original.""" + pass + + +if __name__ == "__main__": + # Suppress TF C++ info/debug logs (0=DEBUG, 1=INFO, 2=WARNING, 3=ERROR) + os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" + # Disable oneDNN to avoid non-deterministic floating point results + os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0" + unittest.main() diff --git a/tensorflow/lite/micro/compression/compressor.py b/tensorflow/lite/micro/compression/compressor.py new file mode 100644 index 00000000000..3d5a635eb09 --- /dev/null +++ b/tensorflow/lite/micro/compression/compressor.py @@ -0,0 +1,80 @@ +# Copyright 2026 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Compression plugin interface.""" + +from dataclasses import dataclass +from typing import Protocol + +from tflite_micro.tensorflow.lite.micro.compression import decode +from tflite_micro.tensorflow.lite.micro.compression import model_editor +from tflite_micro.tensorflow.lite.micro.compression import spec + + +class CompressionError(Exception): + """Raised when compression fails for the reason documented in the message.""" + + def __init__(self, message, wrapped_exception=None): + if wrapped_exception: + super().__init__(f"{message}: {str(wrapped_exception)}") + else: + super().__init__(message) + self.original_exception = wrapped_exception + + +@dataclass +class CompressionResult: + """Result of compressing a tensor. + + Attributes: + encoded_data: The compressed tensor data (e.g., packed indices for LUT). + ancillary_data: The complete ancillary data tensor bytes (DCM + type-specific + data). This is the full buffer contents for the ancillary + tensor. + """ + encoded_data: bytes + ancillary_data: bytes + + +class Compressor(Protocol): + """Protocol that compression plugins must implement. + + Each compression method (LUT, Huffman, Pruning) provides a class implementing + this protocol. The compress() function uses duck typing to call the plugin. + """ + + @property + def decode_type(self) -> decode.DecodeType: + """The DecodeType constant for this compression method.""" + ... + + def compress( + self, + tensor: model_editor.Tensor, + method: spec.CompressionMethod, + ) -> CompressionResult: + """Compress a tensor according to the specified method. + + Args: + tensor: The tensor to compress. Must have data (tensor.array is not None) + and quantization parameters for axis inference. + method: The compression method spec (e.g., LookUpTableCompression). + + Returns: + CompressionResult with encoded tensor data and ancillary data bytes. + + Raises: + CompressionError: If compression fails (e.g., too many unique values + for specified bitwidth, missing quantization, etc.). + """ + ... diff --git a/tensorflow/lite/micro/compression/decode.py b/tensorflow/lite/micro/compression/decode.py new file mode 100644 index 00000000000..df8428310a3 --- /dev/null +++ b/tensorflow/lite/micro/compression/decode.py @@ -0,0 +1,240 @@ +# Copyright 2026 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""DECODE compression module.""" + +# Implements the DECODE operator compression scheme described in the +# "TFLM DECODE Operator Design" document, revised May 20, 2025. +# +# The DECODE operator transforms an encoded tensor, alongside a paired +# ancillary data tensor, into a tensor ready for use as input to any +# operator. For example, an encoded tensor might contain compressed +# data, while the paired ancillary data tensor holds the information +# necessary for decompression. The DECODE operator's output is a fully +# decompressed tensor. +# +# DECODE operators are inserted into the TfLite model subgraph +# immediately before each operation that uses a decodable tensor as +# input. +# +# Ancillary Data Tensor +# +# The ancillary data tensor contains the information necessary for +# decoding. It begins with a 16-byte DECODE Common Metadata (DCM) +# header, followed by decode-type-specific ancillary data. +# +# DECODE Common Metadata (DCM) +# +# Byte 0: Decode type +# 0-127: TFLM-supported decode operations (see below) +# 128-255: Custom operations requiring application-registered +# handlers +# +# Supported decode types: +# +# 0: LUT decompression +# All TFLM tensor types supported in reference and optimized +# code. +# +# 1: Huffman decompression using Xtensa format decode tables +# INT8 and INT16 tensor types only, in reference and optimized +# code. +# +# 2: Pruning decompression +# All TFLM tensor types supported in reference and optimized +# code. +# +# 3-127: Reserved +# +# 128-255: Custom decode types +# Requires user-supplied encoding module and decoding ancillary +# data. +# +# Byte 1: DCM version (currently 1) +# +# Bytes 2-3: Reserved +# +# Bytes 4-15: User-defined +# Used by TFLM decode types to avoid requiring additional alignment +# of metadata or ancillary data. +# +# The 16-byte DCM size ensures that subsequent metadata and ancillary +# data are 128-bit aligned, which is required for some optimized +# decoding operations such as Xtensa LUT decompression. +# +# For TFLM decode types, ancillary data starts immediately after the +# DCM. For custom decode types, the location is determined by +# user-defined metadata. + +from dataclasses import dataclass +from typing import Protocol + + +class DecodeType: + """Decode operation type (0-255). + + Use predefined constants for built-in types or DecodeType.custom() + for custom types: + DecodeType.LUT # 0 + DecodeType.HUFFMAN # 1 + DecodeType.PRUNING # 2 + DecodeType.custom(200) # Custom type 128-255 + """ + + # Built-in decode types (class variables set after class definition) + LUT: 'DecodeType' + HUFFMAN: 'DecodeType' + PRUNING: 'DecodeType' + + def __init__(self, code: int, name: str = None): + """Initialize DecodeType. + + Args: + code: Integer code 0-255 + name: Optional name for the type. If not provided: + - Codes 0-127: Named "TYPE_{code}" + - Codes 128-255: Named "CUSTOM_{code}" + """ + if not 0 <= code <= 255: + raise ValueError(f"Decode type must be 0-255, got {code}") + self.code = code + + # Auto-generate name if not provided + if name is None: + self.name = f"CUSTOM_{code}" if code >= 128 else f"TYPE_{code}" + else: + self.name = name + + self._is_custom = code >= 128 + + @property + def is_custom(self) -> bool: + """True if this is a custom decode type (128-255).""" + return self._is_custom + + @classmethod + def custom(cls, code: int) -> 'DecodeType': + """Create custom decode type (128-255). + + Args: + code: Integer code 128-255 + + Returns: + DecodeType with name CUSTOM_{code} + """ + if not 128 <= code <= 255: + raise ValueError(f"Custom decode type must be 128-255, got {code}") + return cls(code) + + def __int__(self): + """Convert to integer for serialization.""" + return self.code + + def __eq__(self, other): + if isinstance(other, DecodeType): + return self.code == other.code + return self.code == other + + def __repr__(self): + return f"DecodeType.{self.name}({self.code})" + + +# Define built-in decode type constants +DecodeType.LUT = DecodeType(0, "LUT") +DecodeType.HUFFMAN = DecodeType(1, "HUFFMAN") +DecodeType.PRUNING = DecodeType(2, "PRUNING") + + +@dataclass +class DecodeCommonMetadata: + """16-byte DECODE Common Metadata (DCM) header. + + Attributes: + decode_type: Decode operation type. Use DecodeType constants or + DecodeType.custom(code) for custom types. + version: DCM version (currently 1). + user_data: 12 bytes of user-defined data (bytes 4-15 of DCM). Used by TFLM + decode types to avoid requiring additional alignment of metadata + or ancillary data. + """ + decode_type: DecodeType + version: int = 1 + user_data: bytes = b'\x00' * 12 + + def to_bytes(self) -> bytes: + """Serialize DCM to 16-byte sequence.""" + decode_code = int(self.decode_type) + if not 0 <= self.version <= 255: + raise ValueError(f"version must be 0-255, got {self.version}") + if len(self.user_data) < 12: + # Pad with zeros if user_data is too short + user_data = self.user_data + b'\x00' * (12 - len(self.user_data)) + else: + user_data = self.user_data[:12] + + result = bytearray(16) + result[0] = decode_code + result[1] = self.version + # bytes 2-3 remain zero (reserved) + result[4:16] = user_data + return bytes(result) + + +class AncillaryDataSerializer(Protocol): + """Protocol for objects that can serialize ancillary data.""" + + def to_bytes(self) -> bytes: + ... + + +@dataclass +class AncillaryDataTensor: + """Complete Ancillary Data Tensor (ADT): DCM + decode-type-specific data. + + The ADT is stored as a buffer in the TFLite model. It begins with a 16-byte + DCM header, followed by decode-type-specific ancillary data. + + Attributes: + dcm: The DECODE Common Metadata header. + ancillary_data: The decode-type-specific ancillary data, either as raw bytes + or as an object implementing the AncillaryDataSerializer + protocol. May be None if only the DCM is needed. + """ + dcm: DecodeCommonMetadata + ancillary_data: AncillaryDataSerializer | bytes | None = None + + def with_ancillary_data( + self, data: AncillaryDataSerializer | bytes) -> 'AncillaryDataTensor': + """Create new ADT with ancillary data added. + + Args: + data: Ancillary data to add, either as raw bytes or as an object + implementing AncillaryDataSerializer. + + Returns: + New AncillaryDataTensor with the specified ancillary data. + """ + return AncillaryDataTensor(self.dcm, data) + + def to_bytes(self) -> bytes: + """Serialize entire ADT to bytes. + + Returns: + Byte sequence containing DCM followed by ancillary data (if present). + """ + dcm_bytes = self.dcm.to_bytes() + if self.ancillary_data is None: + return dcm_bytes + if isinstance(self.ancillary_data, bytes): + return dcm_bytes + self.ancillary_data + return dcm_bytes + self.ancillary_data.to_bytes() diff --git a/tensorflow/lite/micro/compression/decode_insert.py b/tensorflow/lite/micro/compression/decode_insert.py new file mode 100644 index 00000000000..fa91896e538 --- /dev/null +++ b/tensorflow/lite/micro/compression/decode_insert.py @@ -0,0 +1,280 @@ +# Copyright 2026 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""DECODE operator insertion into TFLite model graphs. + +This module inserts DECODE operators into a compressed model. DECODE operators +transform encoded tensors (with their paired ancillary data tensors) into +tensors ready for use by downstream operators. + +The DECODE operator is registered as a custom operator named "TFLM_DECODE". +Each DECODE output requires two inputs: the encoded tensor and the ancillary +data tensor (containing the DCM header and decode-type-specific data). +""" + +import warnings +from collections import defaultdict +from dataclasses import dataclass +from typing import Optional + +from tflite_micro.tensorflow.lite.micro.compression import compressor +from tflite_micro.tensorflow.lite.micro.compression import model_editor +from tflite_micro.tensorflow.lite.python import schema_py_generated as tflite + +# Custom operator name for DECODE +DECODE_CUSTOM_OP_NAME = "TFLM_DECODE" + + +@dataclass +class _CompressedTensorInfo: + """Information about a compressed tensor for DECODE insertion.""" + subgraph_idx: int + tensor_idx: int + tensor: model_editor.Tensor + encoded_data: bytes + ancillary_data: bytes + consumers: list[model_editor.Operator] + + +def _find_tensor_consumers( + subgraph: model_editor.Subgraph, + tensor: model_editor.Tensor, +) -> list[model_editor.Operator]: + """Find all operators in subgraph that use tensor as an input.""" + consumers = [] + for op in subgraph.operators: + if tensor in op.inputs: + consumers.append(op) + return consumers + + +def _create_ancillary_tensor( + ancillary_data: bytes, + original_tensor: model_editor.Tensor, +) -> model_editor.Tensor: + """Create an ancillary data tensor for a compressed tensor. + + Args: + ancillary_data: The complete ancillary data (DCM + type-specific data). + original_tensor: The original tensor being decoded, for naming. + + Returns: + A new Tensor containing the ancillary data. + """ + name = None + if original_tensor.name: + name = f"{original_tensor.name}_ancillary" + + return model_editor.Tensor( + shape=(len(ancillary_data), ), + dtype=tflite.TensorType.UINT8, + data=ancillary_data, + name=name, + ) + + +def _create_output_tensor( + original_tensor: model_editor.Tensor, ) -> model_editor.Tensor: + """Create the output tensor for a DECODE operator. + + The output tensor has the same shape, dtype, and quantization as the + original tensor would have when decoded. It has no data---the DECODE + operator produces its values at runtime. + + Args: + original_tensor: The original tensor being decoded. + + Returns: + A new Tensor for the DECODE output. + """ + name = None + if original_tensor.name: + name = f"{original_tensor.name}_decoded" + + return model_editor.Tensor( + shape=original_tensor.shape, + dtype=original_tensor.dtype, + quantization=original_tensor.quantization, + name=name, + ) + + +def _rewire_consumers( + consumers: list[model_editor.Operator], + old_tensor: model_editor.Tensor, + new_tensor: model_editor.Tensor, +) -> None: + """Replace old_tensor with new_tensor in all consumer inputs.""" + for consumer in consumers: + consumer.inputs = [ + new_tensor if t is old_tensor else t for t in consumer.inputs + ] + + +def _rewrite_encoded_tensor( + tensor: model_editor.Tensor, + encoded_data: bytes, +) -> None: + """Rewrite a compressed tensor to hold encoded data. + + The original tensor contained uncompressed values with quantization. After + compression, it holds packed indices (or other encoded form) as raw bytes. + This function updates the tensor in place to reflect its new role. + + Args: + tensor: The tensor to rewrite. + encoded_data: The compressed/encoded data bytes. + """ + tensor.shape = (len(encoded_data), ) + tensor.dtype = tflite.TensorType.UINT8 + tensor.quantization = None + tensor.buffer.data = encoded_data + + +def insert_decode_operators( + model: model_editor.Model, + compression_results: dict[tuple[int, int], compressor.CompressionResult], +) -> None: + """Insert DECODE operators for all compressed tensors. + + This function modifies the model in-place, inserting DECODE operators + before any operator that uses a compressed tensor as input. + + A separate DECODE is inserted before each consumer, rather than sharing one + DECODE output among all consumers. This is required because the interpreter's + alternate decompression memory resets its allocation offset for each DECODE's + Prepare, causing all DECODE outputs to be allocated at the same address. If + two consumers share one DECODE and another DECODE runs between them, the + intervening DECODE overwrites the shared output, corrupting data for the + second consumer. + + For each consumer of a compressed tensor: + 1. Create an ancillary data tensor containing DCM + type-specific data + 2. Create an output tensor with the same shape/dtype as the decoded tensor + 3. Insert a DECODE operator immediately before the consumer + 4. Rewire the consumer to use the DECODE output + + Args: + model: The model to modify in-place. + compression_results: Map from (subgraph_idx, tensor_idx) to the + CompressionResult containing ancillary_data. + """ + # Group compressed tensors by subgraph + by_subgraph: dict[int, list[_CompressedTensorInfo]] = defaultdict(list) + + for (sg_idx, tensor_idx), result in compression_results.items(): + subgraph = model.subgraphs[sg_idx] + tensor = subgraph.tensors[tensor_idx] + consumers = _find_tensor_consumers(subgraph, tensor) + + if not consumers: + # Check if tensor is a subgraph output + is_output = tensor in subgraph.outputs + if is_output: + # TODO: Handle compressed tensors that are subgraph outputs. + # This occurs in multi-subgraph models using IF/WHILE where a + # compressed tensor flows out of a subgraph. + raise NotImplementedError( + f"Compressed tensor {tensor.name!r} (subgraph {sg_idx}, " + f"tensor {tensor_idx}) is a subgraph output with no consumers. " + "Compressed subgraph outputs are not yet supported.") + else: + warnings.warn( + f"Compressed tensor {tensor.name!r} (subgraph {sg_idx}, " + f"tensor {tensor_idx}) has no consumers and is not a subgraph " + "output. No DECODE operator will be inserted.", + stacklevel=2) + continue + + info = _CompressedTensorInfo( + subgraph_idx=sg_idx, + tensor_idx=tensor_idx, + tensor=tensor, + encoded_data=result.encoded_data, + ancillary_data=result.ancillary_data, + consumers=consumers, + ) + by_subgraph[sg_idx].append(info) + + # Process each subgraph + for sg_idx, tensor_infos in by_subgraph.items(): + subgraph = model.subgraphs[sg_idx] + + # Group tensor infos by consumer so multiple compressed inputs to the + # same operator get batched into a single DECODE. + consumer_to_infos: dict[model_editor.Operator, list[_CompressedTensorInfo]] + consumer_to_infos = defaultdict(list) + for info in tensor_infos: + for consumer in info.consumers: + if info not in consumer_to_infos[consumer]: + consumer_to_infos[consumer].append(info) + + # Sort consumers by position in reverse so insertions don't invalidate + # earlier positions. + sorted_consumers = sorted( + consumer_to_infos.keys(), + key=lambda op: subgraph.operators.index(op), + reverse=True, + ) + + # Cache ancillary tensors by original tensor to avoid duplicates. Each + # DECODE needs its own output tensor, but ancillary data is identical for + # all DECODEs of the same compressed tensor. + ancillary_cache: dict[model_editor.Tensor, model_editor.Tensor] = {} + + # Track tensors to rewrite after all output tensors are created, since + # _create_output_tensor reads the original tensor's shape/dtype/quantization. + tensors_to_rewrite: dict[model_editor.Tensor, bytes] = {} + + for consumer in sorted_consumers: + decode_inputs = [] + decode_outputs = [] + + for info in consumer_to_infos[consumer]: + # Reuse or create ancillary data tensor + if info.tensor not in ancillary_cache: + ancillary_tensor = _create_ancillary_tensor( + info.ancillary_data, + info.tensor, + ) + subgraph.tensors.append(ancillary_tensor) + ancillary_cache[info.tensor] = ancillary_tensor + tensors_to_rewrite[info.tensor] = info.encoded_data + else: + ancillary_tensor = ancillary_cache[info.tensor] + + # Create output tensor (one per compressed input) + output_tensor = _create_output_tensor(info.tensor) + subgraph.tensors.append(output_tensor) + + decode_inputs.extend([info.tensor, ancillary_tensor]) + decode_outputs.append(output_tensor) + + # Rewire this consumer to use the decoded output + _rewire_consumers([consumer], info.tensor, output_tensor) + + # Create single DECODE operator for all compressed inputs + decode_op = model_editor.Operator( + opcode=tflite.BuiltinOperator.CUSTOM, + custom_code=DECODE_CUSTOM_OP_NAME, + inputs=decode_inputs, + outputs=decode_outputs, + ) + + # Insert DECODE immediately before this consumer + insert_pos = subgraph.operators.index(consumer) + subgraph.operators.insert(insert_pos, decode_op) + + # Rewrite encoded tensors after all output tensors are created + for tensor, encoded_data in tensors_to_rewrite.items(): + _rewrite_encoded_tensor(tensor, encoded_data) diff --git a/tensorflow/lite/micro/compression/decode_insert_test.py b/tensorflow/lite/micro/compression/decode_insert_test.py new file mode 100644 index 00000000000..60965b46676 --- /dev/null +++ b/tensorflow/lite/micro/compression/decode_insert_test.py @@ -0,0 +1,559 @@ +# Copyright 2026 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unit tests for DECODE operator insertion.""" + +import unittest +import warnings + +import numpy as np + +from tflite_micro.tensorflow.lite.micro.compression import compressor +from tflite_micro.tensorflow.lite.micro.compression import decode +from tflite_micro.tensorflow.lite.micro.compression import decode_insert +from tflite_micro.tensorflow.lite.micro.compression import lut +from tflite_micro.tensorflow.lite.micro.compression import model_editor +from tflite_micro.tensorflow.lite.python import schema_py_generated as tflite + + +def _build_simple_fc_model(): + """Build a simple model with one FC operator and compressible weights.""" + # yapf: disable + weights = model_editor.Tensor( + shape=(4, 4), + dtype=tflite.TensorType.INT8, + data=np.array([[1, 2, 1, 2], + [3, 4, 3, 4], + [1, 2, 1, 2], + [3, 4, 3, 4]], dtype=np.int8), + name="weights", + quantization=model_editor.Quantization(scales=0.5, zero_points=0), + ) + # yapf: enable + input_t = model_editor.Tensor( + shape=(1, 4), + dtype=tflite.TensorType.INT8, + name="input", + ) + output_t = model_editor.Tensor( + shape=(1, 4), + dtype=tflite.TensorType.INT8, + name="output", + ) + + model = model_editor.Model(subgraphs=[ + model_editor.Subgraph( + tensors=[weights], + operators=[ + model_editor.Operator( + opcode=tflite.BuiltinOperator.FULLY_CONNECTED, + inputs=[input_t, weights], + outputs=[output_t], + ) + ], + ) + ]) + return model + + +def _build_shared_weights_model(): + """Build model where one tensor is used by multiple operators.""" + weights = model_editor.Tensor( + shape=(4, 4), + dtype=tflite.TensorType.INT8, + data=np.ones((4, 4), dtype=np.int8), + name="shared_weights", + quantization=model_editor.Quantization(scales=0.5, zero_points=0), + ) + input1 = model_editor.Tensor( + shape=(1, 4), + dtype=tflite.TensorType.INT8, + name="input1", + ) + input2 = model_editor.Tensor( + shape=(1, 4), + dtype=tflite.TensorType.INT8, + name="input2", + ) + output1 = model_editor.Tensor( + shape=(1, 4), + dtype=tflite.TensorType.INT8, + name="output1", + ) + output2 = model_editor.Tensor( + shape=(1, 4), + dtype=tflite.TensorType.INT8, + name="output2", + ) + + model = model_editor.Model(subgraphs=[ + model_editor.Subgraph( + tensors=[weights], + operators=[ + model_editor.Operator( + opcode=tflite.BuiltinOperator.FULLY_CONNECTED, + inputs=[input1, weights], + outputs=[output1], + ), + model_editor.Operator( + opcode=tflite.BuiltinOperator.FULLY_CONNECTED, + inputs=[input2, weights], + outputs=[output2], + ), + ], + ) + ]) + return model + + +def _make_dummy_ancillary_data(bitwidth=4) -> bytes: + """Create dummy ancillary data for testing.""" + n_entries = 1 << bitwidth + value_tables = bytes(range(1, n_entries + 1)) + value_tables += b'\x00' * ((-len(value_tables)) % 16) + + lut_data = lut.LutAncillaryData( + bitwidth=bitwidth, + value_table_stride=n_entries, + value_tables=value_tables, + ) + dcm = decode.DecodeCommonMetadata( + decode_type=decode.DecodeType.LUT, + user_data=lut_data.to_user_data(), + ) + return dcm.to_bytes() + lut_data.to_bytes() + + +class TestDecodeInsertion(unittest.TestCase): + """Tests for insert_decode_operators function.""" + + def test_insert_single_decode_operator(self): + """DECODE operator inserted before FC that uses compressed weights.""" + model = _build_simple_fc_model() + weights_tensor = model.subgraphs[0].tensor_by_name("weights") + + # Create compression result + compression_results = { + (0, 0): + compressor.CompressionResult( + encoded_data=b'\x00\x00', + ancillary_data=_make_dummy_ancillary_data(), + ) + } + + # Insert DECODE operators + decode_insert.insert_decode_operators(model, compression_results) + + sg = model.subgraphs[0] + + # Should have 2 operators: DECODE then FC + self.assertEqual(len(sg.operators), 2) + self.assertEqual(sg.operators[0].opcode, tflite.BuiltinOperator.CUSTOM) + self.assertEqual(sg.operators[0].custom_code, + decode_insert.DECODE_CUSTOM_OP_NAME) + self.assertEqual(sg.operators[1].opcode, + tflite.BuiltinOperator.FULLY_CONNECTED) + + def test_decode_inputs_structure(self): + """DECODE operator has correct inputs: encoded tensor + ancillary.""" + model = _build_simple_fc_model() + weights_tensor = model.subgraphs[0].tensor_by_name("weights") + + compression_results = { + (0, 0): + compressor.CompressionResult( + encoded_data=b'\x00\x00', + ancillary_data=_make_dummy_ancillary_data(), + ) + } + + decode_insert.insert_decode_operators(model, compression_results) + + decode_op = model.subgraphs[0].operators[0] + + # DECODE has 2 inputs + self.assertEqual(len(decode_op.inputs), 2) + # First input is the encoded tensor (original weights) + self.assertIs(decode_op.inputs[0], weights_tensor) + # Second input is ancillary tensor + self.assertEqual(decode_op.inputs[1].dtype, tflite.TensorType.UINT8) + + def test_decode_output_structure(self): + """DECODE operator output has correct shape and dtype.""" + model = _build_simple_fc_model() + weights_tensor = model.subgraphs[0].tensor_by_name("weights") + + # Save original properties before rewrite + original_shape = weights_tensor.shape + original_dtype = weights_tensor.dtype + + compression_results = { + (0, 0): + compressor.CompressionResult( + encoded_data=b'\x00\x00', + ancillary_data=_make_dummy_ancillary_data(), + ) + } + + decode_insert.insert_decode_operators(model, compression_results) + + decode_op = model.subgraphs[0].operators[0] + output = decode_op.outputs[0] + + # Output matches original (pre-rewrite) tensor shape and dtype + self.assertEqual(output.shape, original_shape) + self.assertEqual(output.dtype, original_dtype) + + def test_consumer_rewired_to_decode_output(self): + """FC operator input rewired to use DECODE output.""" + model = _build_simple_fc_model() + weights_tensor = model.subgraphs[0].tensor_by_name("weights") + + compression_results = { + (0, 0): + compressor.CompressionResult( + encoded_data=b'\x00\x00', + ancillary_data=_make_dummy_ancillary_data(), + ) + } + + decode_insert.insert_decode_operators(model, compression_results) + + decode_op = model.subgraphs[0].operators[0] + fc_op = model.subgraphs[0].operators[1] + + # FC's second input (weights) should now be DECODE's output + self.assertIs(fc_op.inputs[1], decode_op.outputs[0]) + # Original weights tensor should NOT be in FC inputs + self.assertNotIn(weights_tensor, fc_op.inputs) + + def test_shared_tensor_decode_per_consumer(self): + """Tensor used by multiple ops gets separate DECODE for each consumer.""" + model = _build_shared_weights_model() + weights_tensor = model.subgraphs[0].tensor_by_name("shared_weights") + + compression_results = { + (0, 0): + compressor.CompressionResult( + encoded_data=b'\x00\x00', + ancillary_data=_make_dummy_ancillary_data(), + ) + } + + decode_insert.insert_decode_operators(model, compression_results) + + sg = model.subgraphs[0] + + # Should have 4 operators: 2 DECODEs + 2 FCs (DECODE before each FC) + self.assertEqual(len(sg.operators), 4) + self.assertEqual(sg.operators[0].opcode, tflite.BuiltinOperator.CUSTOM) + self.assertEqual(sg.operators[1].opcode, + tflite.BuiltinOperator.FULLY_CONNECTED) + self.assertEqual(sg.operators[2].opcode, tflite.BuiltinOperator.CUSTOM) + self.assertEqual(sg.operators[3].opcode, + tflite.BuiltinOperator.FULLY_CONNECTED) + + decode_op1 = sg.operators[0] + fc_op1 = sg.operators[1] + decode_op2 = sg.operators[2] + fc_op2 = sg.operators[3] + + # Each FC should use its own DECODE's output + self.assertIs(fc_op1.inputs[1], decode_op1.outputs[0]) + self.assertIs(fc_op2.inputs[1], decode_op2.outputs[0]) + # The two DECODEs should have different outputs + self.assertIsNot(decode_op1.outputs[0], decode_op2.outputs[0]) + # The two DECODEs should share the same ancillary tensor + self.assertIs(decode_op1.inputs[1], decode_op2.inputs[1]) + + def test_ancillary_tensor_contains_dcm(self): + """Ancillary tensor data contains valid DCM header.""" + model = _build_simple_fc_model() + + ancillary_data = _make_dummy_ancillary_data() + compression_results = { + (0, 0): + compressor.CompressionResult( + encoded_data=b'\x00\x00', + ancillary_data=ancillary_data, + ) + } + + decode_insert.insert_decode_operators(model, compression_results) + + decode_op = model.subgraphs[0].operators[0] + ancillary_tensor = decode_op.inputs[1] + + # Ancillary tensor data should match what we provided + self.assertEqual(bytes(ancillary_tensor.array), ancillary_data) + + # Verify DCM header + dcm_bytes = ancillary_tensor.array[:16] + self.assertEqual(dcm_bytes[0], 0) # decode_type = LUT + self.assertEqual(dcm_bytes[1], 1) # DCM version + + def test_no_consumers_no_decode(self): + """Tensor with no consumers gets no DECODE operator and emits warning.""" + # Create model where compressed tensor is not used as input + unused_tensor = model_editor.Tensor( + shape=(4, 4), + dtype=tflite.TensorType.INT8, + data=np.ones((4, 4), dtype=np.int8), + name="unused", + quantization=model_editor.Quantization(scales=0.5, zero_points=0), + ) + input_t = model_editor.Tensor( + shape=(1, 4), + dtype=tflite.TensorType.INT8, + name="input", + ) + output_t = model_editor.Tensor( + shape=(1, 4), + dtype=tflite.TensorType.INT8, + name="output", + ) + other_weights = model_editor.Tensor( + shape=(4, 4), + dtype=tflite.TensorType.INT8, + data=np.ones((4, 4), dtype=np.int8), + name="other_weights", + quantization=model_editor.Quantization(scales=0.5, zero_points=0), + ) + + model = model_editor.Model(subgraphs=[ + model_editor.Subgraph( + tensors=[unused_tensor, other_weights], + operators=[ + model_editor.Operator( + opcode=tflite.BuiltinOperator.FULLY_CONNECTED, + inputs=[input_t, other_weights], + outputs=[output_t], + ) + ], + ) + ]) + + # Compress the unused tensor + compression_results = { + (0, 0): + compressor.CompressionResult( + encoded_data=b'\x00\x00', + ancillary_data=_make_dummy_ancillary_data(), + ) + } + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + decode_insert.insert_decode_operators(model, compression_results) + + # Should emit a warning about no consumers + self.assertEqual(len(w), 1) + self.assertIn("no consumers", str(w[0].message)) + self.assertIn("unused", str(w[0].message)) + + # Should still have just 1 operator (no DECODE inserted) + self.assertEqual(len(model.subgraphs[0].operators), 1) + + def test_tensor_naming(self): + """Output and ancillary tensors get appropriate names.""" + model = _build_simple_fc_model() + + compression_results = { + (0, 0): + compressor.CompressionResult( + encoded_data=b'\x00\x00', + ancillary_data=_make_dummy_ancillary_data(), + ) + } + + decode_insert.insert_decode_operators(model, compression_results) + + decode_op = model.subgraphs[0].operators[0] + ancillary = decode_op.inputs[1] + output = decode_op.outputs[0] + + self.assertEqual(ancillary.name, "weights_ancillary") + self.assertEqual(output.name, "weights_decoded") + + def test_multiple_compressed_inputs_batched(self): + """CONCATENATION with two compressed inputs gets one batched DECODE.""" + weights_a = model_editor.Tensor( + shape=(4, 4), + dtype=tflite.TensorType.INT8, + data=np.ones((4, 4), dtype=np.int8), + name="weights_a", + quantization=model_editor.Quantization(scales=0.5, zero_points=0), + ) + weights_b = model_editor.Tensor( + shape=(4, 4), + dtype=tflite.TensorType.INT8, + data=np.ones((4, 4), dtype=np.int8), + name="weights_b", + quantization=model_editor.Quantization(scales=0.25, zero_points=0), + ) + output_t = model_editor.Tensor( + shape=(4, 8), + dtype=tflite.TensorType.INT8, + name="output", + ) + + concat_op = model_editor.Operator( + opcode=tflite.BuiltinOperator.CONCATENATION, + inputs=[weights_a, weights_b], + outputs=[output_t], + ) + + model = model_editor.Model(subgraphs=[ + model_editor.Subgraph( + tensors=[weights_a, weights_b], + operators=[concat_op], + ) + ]) + + ancillary_a = _make_dummy_ancillary_data(bitwidth=2) + ancillary_b = _make_dummy_ancillary_data(bitwidth=4) + compression_results = { + (0, 0): + compressor.CompressionResult( + encoded_data=b'\x00\x01', + ancillary_data=ancillary_a, + ), + (0, 1): + compressor.CompressionResult( + encoded_data=b'\x02\x03', + ancillary_data=ancillary_b, + ), + } + + decode_insert.insert_decode_operators(model, compression_results) + + sg = model.subgraphs[0] + + # One DECODE + one CONCATENATION + self.assertEqual(len(sg.operators), 2) + decode_op = sg.operators[0] + self.assertEqual(decode_op.opcode, tflite.BuiltinOperator.CUSTOM) + self.assertEqual(decode_op.custom_code, + decode_insert.DECODE_CUSTOM_OP_NAME) + + # DECODE has 4 inputs (enc_a, anc_a, enc_b, anc_b) and 2 outputs + self.assertEqual(len(decode_op.inputs), 4) + self.assertEqual(len(decode_op.outputs), 2) + + # Each ancillary tensor carries its own distinct data + self.assertNotEqual(ancillary_a, ancillary_b) + self.assertEqual(bytes(decode_op.inputs[1].array), ancillary_a) + self.assertEqual(bytes(decode_op.inputs[3].array), ancillary_b) + + # CONCATENATION rewired to DECODE outputs + self.assertIs(sg.operators[1].inputs[0], decode_op.outputs[0]) + self.assertIs(sg.operators[1].inputs[1], decode_op.outputs[1]) + + def test_mixed_compressed_and_uncompressed_inputs(self): + """CONCATENATION with one compressed and one plain input.""" + weights = model_editor.Tensor( + shape=(4, 4), + dtype=tflite.TensorType.INT8, + data=np.ones((4, 4), dtype=np.int8), + name="weights", + quantization=model_editor.Quantization(scales=0.5, zero_points=0), + ) + plain = model_editor.Tensor( + shape=(4, 4), + dtype=tflite.TensorType.INT8, + data=np.zeros((4, 4), dtype=np.int8), + name="plain", + ) + output_t = model_editor.Tensor( + shape=(4, 8), + dtype=tflite.TensorType.INT8, + name="output", + ) + + concat_op = model_editor.Operator( + opcode=tflite.BuiltinOperator.CONCATENATION, + inputs=[weights, plain], + outputs=[output_t], + ) + + model = model_editor.Model(subgraphs=[ + model_editor.Subgraph( + tensors=[weights, plain], + operators=[concat_op], + ) + ]) + + # Only compress weights, not plain + compression_results = { + (0, 0): + compressor.CompressionResult( + encoded_data=b'\x00\x01', + ancillary_data=_make_dummy_ancillary_data(), + ), + } + + decode_insert.insert_decode_operators(model, compression_results) + + sg = model.subgraphs[0] + + # One DECODE + one CONCATENATION + self.assertEqual(len(sg.operators), 2) + decode_op = sg.operators[0] + + # DECODE has 2 inputs and 1 output (only the compressed tensor) + self.assertEqual(len(decode_op.inputs), 2) + self.assertEqual(len(decode_op.outputs), 1) + + # CONCATENATION: first input rewired to DECODE output, second unchanged + self.assertIs(sg.operators[1].inputs[0], decode_op.outputs[0]) + self.assertIs(sg.operators[1].inputs[1], plain) + + def test_encoded_tensor_rewritten(self): + """Compressed tensor is rewritten with encoded data, UINT8 type, no quant.""" + model = _build_simple_fc_model() + weights_tensor = model.subgraphs[0].tensor_by_name("weights") + + encoded_data = b'\xAB\xCD\xEF' + compression_results = { + (0, 0): + compressor.CompressionResult( + encoded_data=encoded_data, + ancillary_data=_make_dummy_ancillary_data(), + ) + } + + decode_insert.insert_decode_operators(model, compression_results) + + # Original tensor should be rewritten + self.assertEqual(weights_tensor.shape, (len(encoded_data), )) + self.assertEqual(weights_tensor.dtype, tflite.TensorType.UINT8) + self.assertIsNone(weights_tensor.quantization) + self.assertEqual(weights_tensor.buffer.data, encoded_data) + + +class TestHelperFunctions(unittest.TestCase): + """Tests for internal helper functions.""" + + def test_find_tensor_consumers(self): + """_find_tensor_consumers finds all ops using a tensor.""" + model = _build_shared_weights_model() + sg = model.subgraphs[0] + weights = sg.tensor_by_name("shared_weights") + + consumers = decode_insert._find_tensor_consumers(sg, weights) + + self.assertEqual(len(consumers), 2) + + +if __name__ == "__main__": + unittest.main() diff --git a/tensorflow/lite/micro/compression/decode_test.py b/tensorflow/lite/micro/compression/decode_test.py new file mode 100644 index 00000000000..eca3b42b2b4 --- /dev/null +++ b/tensorflow/lite/micro/compression/decode_test.py @@ -0,0 +1,155 @@ +# Copyright 2026 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from tflite_micro.tensorflow.lite.micro.compression import decode + + +class TestDecodeCommonMetadata(unittest.TestCase): + + def testBasicSerialization(self): + dcm = decode.DecodeCommonMetadata(decode_type=decode.DecodeType.LUT) + result = dcm.to_bytes() + + # Should be exactly 16 bytes + self.assertEqual(len(result), 16) + + # Byte 0: decode_type + self.assertEqual(result[0], 0) + + # Byte 1: version (default 1) + self.assertEqual(result[1], 1) + + # Bytes 2-3: reserved (should be zero) + self.assertEqual(result[2], 0) + self.assertEqual(result[3], 0) + + # Bytes 4-15: user_data (default all zeros) + self.assertEqual(result[4:16], b'\x00' * 12) + + def testCustomVersion(self): + dcm = decode.DecodeCommonMetadata(decode_type=1, version=2) + result = dcm.to_bytes() + + self.assertEqual(result[0], 1) + self.assertEqual(result[1], 2) + + def testUserData(self): + user_data = b'\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c' + dcm = decode.DecodeCommonMetadata(decode_type=0, user_data=user_data) + result = dcm.to_bytes() + + self.assertEqual(result[4:16], user_data) + + def testUserDataPadding(self): + # User data shorter than 12 bytes should be padded with zeros + user_data = b'\x01\x02\x03' + dcm = decode.DecodeCommonMetadata(decode_type=0, user_data=user_data) + result = dcm.to_bytes() + + expected = b'\x01\x02\x03' + b'\x00' * 9 + self.assertEqual(result[4:16], expected) + + def testUserDataTruncation(self): + # User data longer than 12 bytes should be truncated + user_data = b'\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f' + dcm = decode.DecodeCommonMetadata(decode_type=0, user_data=user_data) + result = dcm.to_bytes() + + self.assertEqual(result[4:16], user_data[:12]) + + def testDecodeTypeRange(self): + # Valid decode types: 0-255 + decode.DecodeCommonMetadata(decode_type=decode.DecodeType.LUT).to_bytes() + decode.DecodeCommonMetadata(decode_type=decode.DecodeType(127)).to_bytes() + decode.DecodeCommonMetadata( + decode_type=decode.DecodeType.custom(255)).to_bytes() + + # Invalid decode types should raise ValueError + with self.assertRaises(ValueError): + decode.DecodeCommonMetadata(decode_type=decode.DecodeType(-1)).to_bytes() + with self.assertRaises(ValueError): + decode.DecodeCommonMetadata( + decode_type=decode.DecodeType(256)).to_bytes() + + def testVersionRange(self): + # Valid versions: 0-255 + decode.DecodeCommonMetadata(decode_type=0, version=0).to_bytes() + decode.DecodeCommonMetadata(decode_type=0, version=255).to_bytes() + + # Invalid versions should raise ValueError + with self.assertRaises(ValueError): + decode.DecodeCommonMetadata(decode_type=0, version=-1).to_bytes() + with self.assertRaises(ValueError): + decode.DecodeCommonMetadata(decode_type=0, version=256).to_bytes() + + +class TestAncillaryDataTensor(unittest.TestCase): + + def testDcmOnly(self): + dcm = decode.DecodeCommonMetadata(decode_type=decode.DecodeType.LUT) + adt = decode.AncillaryDataTensor(dcm) + result = adt.to_bytes() + + # Should be just the 16-byte DCM + self.assertEqual(len(result), 16) + self.assertEqual(result, dcm.to_bytes()) + + def testWithBytesAncillaryData(self): + dcm = decode.DecodeCommonMetadata(decode_type=decode.DecodeType.HUFFMAN) + ancillary = b'\xaa\xbb\xcc\xdd' + adt = decode.AncillaryDataTensor(dcm, ancillary) + result = adt.to_bytes() + + # Should be DCM + ancillary data + self.assertEqual(len(result), 20) + self.assertEqual(result[:16], dcm.to_bytes()) + self.assertEqual(result[16:], ancillary) + + def testWithAncillaryDataMethod(self): + dcm = decode.DecodeCommonMetadata(decode_type=decode.DecodeType.PRUNING) + adt = decode.AncillaryDataTensor(dcm) + + ancillary = b'\x11\x22\x33\x44' + adt_with_data = adt.with_ancillary_data(ancillary) + result = adt_with_data.to_bytes() + + # Original ADT should be unchanged + self.assertEqual(adt.to_bytes(), dcm.to_bytes()) + + # New ADT should have ancillary data + self.assertEqual(len(result), 20) + self.assertEqual(result[:16], dcm.to_bytes()) + self.assertEqual(result[16:], ancillary) + + def testWithSerializerProtocol(self): + # Test with an object that implements AncillaryDataSerializer + class MockSerializer: + + def to_bytes(self): + return b'\xff\xee\xdd\xcc' + + dcm = decode.DecodeCommonMetadata(decode_type=decode.DecodeType(3)) + serializer = MockSerializer() + adt = decode.AncillaryDataTensor(dcm, serializer) + result = adt.to_bytes() + + self.assertEqual(len(result), 20) + self.assertEqual(result[:16], dcm.to_bytes()) + self.assertEqual(result[16:], b'\xff\xee\xdd\xcc') + + +if __name__ == '__main__': + unittest.main() diff --git a/tensorflow/lite/micro/compression/huffman.py b/tensorflow/lite/micro/compression/huffman.py new file mode 100644 index 00000000000..e539827eae4 --- /dev/null +++ b/tensorflow/lite/micro/compression/huffman.py @@ -0,0 +1,60 @@ +# Copyright 2026 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Huffman compression plugin (stub). + +This module provides a placeholder for Huffman compression using Xtensa-format +decode tables. The actual implementation is not yet available. + +Supported tensor types (when implemented): INT8, INT16 +""" + +from tflite_micro.tensorflow.lite.micro.compression import compressor +from tflite_micro.tensorflow.lite.micro.compression import decode +from tflite_micro.tensorflow.lite.micro.compression import model_editor +from tflite_micro.tensorflow.lite.micro.compression import spec + + +class HuffmanCompressor(compressor.Compressor): + """Huffman compression plugin (stub). + + This stub exists to validate the plugin architecture. The actual Huffman + compression algorithm using Xtensa-format decode tables is not yet + implemented. + """ + + @property + def decode_type(self) -> decode.DecodeType: + """Returns DecodeType.HUFFMAN.""" + return decode.DecodeType.HUFFMAN + + def compress( + self, + tensor: model_editor.Tensor, + method: spec.CompressionMethod, + ) -> compressor.CompressionResult: + """Compress a tensor using Huffman encoding. + + Args: + tensor: The tensor to compress. + method: Must be a HuffmanCompression instance. + + Returns: + CompressionResult (not implemented). + + Raises: + CompressionError: Always, since this is a stub. + """ + raise compressor.CompressionError( + "Huffman compression not yet implemented. " + "This stub exists to validate the plugin architecture.") diff --git a/tensorflow/lite/micro/compression/lut.py b/tensorflow/lite/micro/compression/lut.py new file mode 100644 index 00000000000..991288f54cc --- /dev/null +++ b/tensorflow/lite/micro/compression/lut.py @@ -0,0 +1,318 @@ +# Copyright 2026 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""LUT (Look-Up Table) compression plugin.""" + +import sys +from dataclasses import dataclass, field +from typing import Optional + +import bitarray +import bitarray.util +import numpy as np + +from tflite_micro.tensorflow.lite.micro.compression import compressor +from tflite_micro.tensorflow.lite.micro.compression import decode +from tflite_micro.tensorflow.lite.micro.compression import model_editor +from tflite_micro.tensorflow.lite.micro.compression import spec + + +@dataclass +class LutCompressedArray: + """Intermediate representation of LUT-compressed data. + + Attributes: + compression_axis: The axis along which compression was performed, or None + for per-tensor compression. + lookup_tables: List of value lookup tables. One table for per-tensor + compression, or one per channel for per-channel compression. + indices: Array of indices into the lookup tables, same shape as original. + """ + compression_axis: Optional[int] = None + lookup_tables: list[np.ndarray] = field(default_factory=list) + indices: np.ndarray = field(default_factory=lambda: np.array([])) + + @property + def index_bitwidth(self) -> int: + """Returns the number of bits required to encode the indices.""" + if self.indices is None or self.indices.size == 0: + raise ValueError("No indices to compute bitwidth from") + max_index = int(np.max(self.indices)) + return max_index.bit_length() or 1 + + +@dataclass +class LutAncillaryData: + """LUT-specific ancillary data matching C++ decode_state_lut.cc format. + + The LUT ancillary data uses the DCM user_data bytes (4-15) plus value tables: + - Byte 4: LUT version (currently 1) + - Byte 5: Params (lower 3 bits = bitwidth, 1-7) + - Byte 6: Value table channel stride (elements per channel) + - Bytes 7-15: Reserved (zeros) + - Bytes 16+: Value tables (concatenated, stride elements per channel) + + Attributes: + lut_version: LUT format version (currently 1). + bitwidth: Number of bits per index (1-7). + value_table_stride: Number of elements per channel in value tables. + value_tables: Packed value table data following the DCM. + """ + lut_version: int = 1 + bitwidth: int = 4 + value_table_stride: int = 16 + value_tables: bytes = b'' + + def __post_init__(self): + if not 1 <= self.bitwidth <= 7: + raise ValueError(f"bitwidth must be 1-7, got {self.bitwidth}") + if not 0 <= self.value_table_stride <= 128: + raise ValueError( + f"value_table_stride must be 0-128, got {self.value_table_stride}") + + def to_user_data(self) -> bytes: + """Serialize to 12-byte user_data for DCM bytes 4-15.""" + user_data = bytearray(12) + user_data[0] = self.lut_version + user_data[1] = self.bitwidth & 0x07 + user_data[2] = self.value_table_stride + # Bytes 3-11 (DCM bytes 7-15) remain zero (reserved) + return bytes(user_data) + + def to_bytes(self) -> bytes: + """Serialize for use as AncillaryDataTensor.ancillary_data.""" + # This returns the type-specific data that follows the DCM header. + # For LUT, that's just the value tables since user_data is in DCM. + return self.value_tables + + +def compress_array(tensor: np.ndarray, + axis: Optional[int]) -> LutCompressedArray: + """Compresses the given tensor using lookup tables. + + Args: + tensor: The tensor to be compressed. + axis: The axis along which to compress. If an axis is given, a lookup table + is created for each slice along the axis. If axis is None, a single + lookup table is used for the entire tensor. + + Compressing a tensor with a lookup table per slice along a particular + axis is analogous to quantizing a tensor with different quantization + parameters per slice along a particular axis (dimension). + + Returns: + LutCompressedArray containing lookup tables and indices. + """ + compressed = LutCompressedArray() + compressed.compression_axis = axis + + if axis is None: + # Compute unique values and indices for the entire tensor + values, indices = np.unique(tensor, return_inverse=True) + compressed.lookup_tables.append(values) + compressed.indices = indices.reshape(tensor.shape) + else: + # Iterate over slices along the compression axis + slice_indices = [] + for slice in np.moveaxis(tensor, axis, 0): + values, indices = np.unique(slice, return_inverse=True) + compressed.lookup_tables.append(values) + indices = indices.reshape(slice.shape) + slice_indices.append(indices) + + # Reconstruct a tensor of indices from the slices + stacked = np.stack(slice_indices, axis=0) + compressed.indices = np.moveaxis(stacked, 0, axis) + + return compressed + + +def identify_compression_axis(tensor: model_editor.Tensor) -> Optional[int]: + """Determines the axis along which to compress. + + The axis along which to compress is inferred from the tensor's quantization + parameters. Unquantized tensors use per-tensor compression. + + Args: + tensor: The tensor to analyze. + + Returns: + The axis along which to compress, or None to indicate one value table for + the entire tensor. + + Raises: + CompressionError: If the axis cannot be determined from quantization. + """ + q = tensor.quantization + if q is None: + return None + + # model_editor wraps quantization, access scales/axis from wrapper + scales = q.scales if isinstance(q.scales, list) else [q.scales] + quantization_channels = len(scales) + + if quantization_channels == 1: + return None + + if q.axis is not None and q.axis < len(tensor.shape): + if quantization_channels == tensor.shape[q.axis]: + return q.axis + + raise compressor.CompressionError( + "Invalid or no quantization parameters from which to " + "infer the axis along which tensor should be compressed.") + + +def check_bitwidth(compressed: int, specified: int, tensor_spec: spec.Tensor): + """Validates that the specified bitwidth is sufficient. + + It is an error if the bitwidth required to compress a tensor exceeds the + specified bitwith, and a warning if the tensor can be compressed in less than + the specified bitwidth. The latter is allowed, and is not an error, to permit + testing with larger bitwidths without re-binning a model. + + Args: + compressed: The bitwidth required by the compressed data. + specified: The bitwidth specified in the compression spec. + tensor_spec: The tensor spec, for error messages. + + Raises: + CompressionError: If specified bitwidth is too small. + """ + if compressed > specified: + raise compressor.CompressionError( + f"index_bitwidth too small: {compressed} bits needed to " + f"enumerate unique values in tensor specified in {tensor_spec}") + elif compressed < specified: + print( + f"warning: index_bitwidth too large: only {compressed} " + f"bits needed to enumerate unique values in tensor specified in " + f"{tensor_spec}", + file=sys.stderr) + + +def pack_indices(indices: np.ndarray, bitwidth: int) -> bytes: + """Packs indices into a bytearray using bitwidth-sized fields. + + Args: + indices: Array of indices to pack. + bitwidth: Number of bits per index. + + Returns: + Packed bytes with indices in big-endian bit order. + """ + endianness = "big" + bits = bitarray.bitarray(endian=endianness) + for i in indices.ravel(): + bits.extend( + bitarray.util.int2ba(int(i), length=bitwidth, endian=endianness)) + return bits.tobytes() + + +def pack_lookup_tables(tables: list[np.ndarray], table_len: int) -> bytes: + """Packs the value tables of a LutCompressedArray. + + Pack the value tables of a LutCompressedArray into a bytes object in the + format writable to a value_table buffer in the .tflite flatbuffer. The + tables are concatenated. + + Args: + tables: List of numpy arrays containing lookup table values. + table_len: Length to pad each table to (typically 2**bitwidth). + + Returns: + Packed bytes containing all tables concatenated. + """ + buffer = bytearray() + for t in tables: + padding_needed = table_len - len(t) + padded = np.pad(t, (0, padding_needed), mode='constant', constant_values=0) + buffer.extend(padded.tobytes()) + return bytes(buffer) + + +class LutCompressor(compressor.Compressor): + """LUT compression plugin implementing the Compressor protocol.""" + + @property + def decode_type(self) -> decode.DecodeType: + """Returns DecodeType.LUT.""" + return decode.DecodeType.LUT + + def compress( + self, + tensor: model_editor.Tensor, + method: spec.CompressionMethod, + ) -> compressor.CompressionResult: + """Compress a tensor using LUT compression. + + Args: + tensor: The tensor to compress. + method: Must be a LookUpTableCompression instance. + + Returns: + CompressionResult with packed indices and ancillary data. + + Raises: + CompressionError: If compression fails. + """ + if not isinstance(method, spec.LookUpTableCompression): + raise compressor.CompressionError( + f"LutCompressor requires LookUpTableCompression, got {type(method)}") + + if tensor.array is None: + raise compressor.CompressionError("Tensor has no data to compress") + + spec_bitwidth = method.index_bitwidth + axis = identify_compression_axis(tensor) + compressed = compress_array(tensor.array, axis) + # Note: check_bitwidth requires a spec.Tensor but we don't have it here. + # We'll do a simpler check. + actual_bitwidth = compressed.index_bitwidth + if actual_bitwidth > spec_bitwidth: + raise compressor.CompressionError( + f"index_bitwidth too small: {actual_bitwidth} bits needed, " + f"but only {spec_bitwidth} specified") + elif actual_bitwidth < spec_bitwidth: + print( + f"warning: index_bitwidth larger than necessary: only " + f"{actual_bitwidth} bits needed, but {spec_bitwidth} specified", + file=sys.stderr) + + # Pack indices into bytes + encoded_data = pack_indices(compressed.indices, spec_bitwidth) + + # Pack value tables + table_len = max(len(t) for t in compressed.lookup_tables) + value_tables_bytes = pack_lookup_tables(compressed.lookup_tables, + table_len) + + # Build ancillary data + lut_data = LutAncillaryData( + lut_version=1, + bitwidth=spec_bitwidth, + value_table_stride=table_len, + value_tables=value_tables_bytes, + ) + + # Build complete ancillary data tensor bytes: DCM header + value tables + dcm = decode.DecodeCommonMetadata( + decode_type=self.decode_type, + user_data=lut_data.to_user_data(), + ) + ancillary_data = dcm.to_bytes() + lut_data.to_bytes() + + return compressor.CompressionResult( + encoded_data=encoded_data, + ancillary_data=ancillary_data, + ) diff --git a/tensorflow/lite/micro/compression/lut_test.py b/tensorflow/lite/micro/compression/lut_test.py new file mode 100644 index 00000000000..d01dcfd4260 --- /dev/null +++ b/tensorflow/lite/micro/compression/lut_test.py @@ -0,0 +1,405 @@ +# Copyright 2026 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unit tests for LUT compression plugin.""" + +import numpy as np +import unittest + +from tflite_micro.tensorflow.lite.micro.compression import compressor +from tflite_micro.tensorflow.lite.micro.compression import decode +from tflite_micro.tensorflow.lite.micro.compression import lut +from tflite_micro.tensorflow.lite.micro.compression import model_editor +from tflite_micro.tensorflow.lite.micro.compression import spec +from tflite_micro.tensorflow.lite.python import schema_py_generated as tflite + + +class TestCompressArray(unittest.TestCase): + """Tests for the compress_array function.""" + + def test_per_tensor_basic(self): + """Per-tensor compression extracts unique values.""" + array = np.array([1, 2, 1, 2, 3, 3], dtype=np.int8) + compressed = lut.compress_array(array, axis=None) + + self.assertIsNone(compressed.compression_axis) + self.assertEqual(len(compressed.lookup_tables), 1) + np.testing.assert_array_equal(compressed.lookup_tables[0], [1, 2, 3]) + # Indices should map back to original values + reconstructed = compressed.lookup_tables[0][compressed.indices] + np.testing.assert_array_equal(reconstructed, array) + + def test_per_tensor_preserves_shape(self): + """Indices array has same shape as input.""" + # yapf: disable + array = np.array([[1, 2], + [3, 1], + [2, 3]], dtype=np.int8) + # yapf: enable + compressed = lut.compress_array(array, axis=None) + + self.assertEqual(compressed.indices.shape, array.shape) + + def test_per_channel_axis0(self): + """Per-channel compression along axis 0.""" + # Each row gets its own value table + # yapf: disable + array = np.array([[1, 1, 1], + [5, 5, 5], + [9, 9, 9]], dtype=np.int8) + # yapf: enable + compressed = lut.compress_array(array, axis=0) + + self.assertEqual(compressed.compression_axis, 0) + self.assertEqual(len(compressed.lookup_tables), 3) + np.testing.assert_array_equal(compressed.lookup_tables[0], [1]) + np.testing.assert_array_equal(compressed.lookup_tables[1], [5]) + np.testing.assert_array_equal(compressed.lookup_tables[2], [9]) + + def test_per_channel_axis1(self): + """Per-channel compression along axis 1.""" + # Each column gets its own value table + # yapf: disable + array = np.array([[1, 5], + [1, 5], + [1, 5]], dtype=np.int8) + # yapf: enable + compressed = lut.compress_array(array, axis=1) + + self.assertEqual(compressed.compression_axis, 1) + self.assertEqual(len(compressed.lookup_tables), 2) + np.testing.assert_array_equal(compressed.lookup_tables[0], [1]) + np.testing.assert_array_equal(compressed.lookup_tables[1], [5]) + + def test_single_value(self): + """Array with single unique value.""" + array = np.array([7, 7, 7, 7], dtype=np.int8) + compressed = lut.compress_array(array, axis=None) + + self.assertEqual(len(compressed.lookup_tables), 1) + np.testing.assert_array_equal(compressed.lookup_tables[0], [7]) + np.testing.assert_array_equal(compressed.indices, [0, 0, 0, 0]) + + def test_bitwidth_calculation(self): + """Index bitwidth is computed correctly.""" + # 3 unique values -> 2 bits needed + array = np.array([0, 1, 2], dtype=np.int8) + compressed = lut.compress_array(array, axis=None) + self.assertEqual(compressed.index_bitwidth, 2) + + # 4 unique values -> 2 bits needed + array = np.array([0, 1, 2, 3], dtype=np.int8) + compressed = lut.compress_array(array, axis=None) + self.assertEqual(compressed.index_bitwidth, 2) + + # 5 unique values -> 3 bits needed + array = np.array([0, 1, 2, 3, 4], dtype=np.int8) + compressed = lut.compress_array(array, axis=None) + self.assertEqual(compressed.index_bitwidth, 3) + + def test_bitwidth_single_value(self): + """Single unique value requires 1 bit.""" + array = np.array([42, 42, 42], dtype=np.int8) + compressed = lut.compress_array(array, axis=None) + self.assertEqual(compressed.index_bitwidth, 1) + + +class TestPackIndices(unittest.TestCase): + """Tests for the pack_indices function.""" + + def test_4bit_packing(self): + """Pack indices into 4-bit fields.""" + indices = np.array([1, 2, 3, 0]) + result = lut.pack_indices(indices, bitwidth=4) + # Big-endian: 0001 0010 | 0011 0000 = 0x12 0x30 + self.assertEqual(result, bytes([0x12, 0x30])) + + def test_2bit_packing(self): + """Pack indices into 2-bit fields.""" + indices = np.array([0, 1, 2, 3]) + result = lut.pack_indices(indices, bitwidth=2) + # Big-endian: 00 01 10 11 = 0x1B + self.assertEqual(result, bytes([0x1B])) + + def test_3bit_packing(self): + """Pack indices into 3-bit fields.""" + indices = np.array([0, 1, 2, 3, 4, 5, 6, 7]) + result = lut.pack_indices(indices, bitwidth=3) + # 000 001 010 011 | 100 101 110 111 + # 00000101 | 00111001 | 01110111 = 0x05 0x39 0x77 + self.assertEqual(result, bytes([0x05, 0x39, 0x77])) + + def test_1bit_packing(self): + """Pack indices into 1-bit fields.""" + indices = np.array([0, 1, 0, 1, 1, 0, 1, 0]) + result = lut.pack_indices(indices, bitwidth=1) + # 0 1 0 1 1 0 1 0 = 0x5A + self.assertEqual(result, bytes([0x5A])) + + def test_multidimensional_flattens(self): + """Multidimensional indices are flattened row-major.""" + # yapf: disable + indices = np.array([[0, 1], + [2, 3]]) + # yapf: enable + result = lut.pack_indices(indices, bitwidth=4) + # 0000 0001 | 0010 0011 = 0x01 0x23 + self.assertEqual(result, bytes([0x01, 0x23])) + + +class TestPackLookupTables(unittest.TestCase): + """Tests for the pack_lookup_tables function.""" + + def test_single_table_int8(self): + """Pack single INT8 lookup table.""" + tables = [np.array([10, 20, 30], dtype=np.int8)] + result = lut.pack_lookup_tables(tables, table_len=4) + # Values: 10, 20, 30, 0 (padding) + self.assertEqual(result, bytes([10, 20, 30, 0])) + + def test_multiple_tables(self): + """Pack multiple lookup tables.""" + tables = [ + np.array([1, 2], dtype=np.int8), + np.array([3, 4], dtype=np.int8), + ] + result = lut.pack_lookup_tables(tables, table_len=4) + # Table 1: 1, 2, 0, 0 | Table 2: 3, 4, 0, 0 + self.assertEqual(result, bytes([1, 2, 0, 0, 3, 4, 0, 0])) + + def test_int16_little_endian(self): + """INT16 values are packed in native byte order.""" + tables = [np.array([0x1234, 0x5678], dtype=' _IteratorTo: - return self._cls(self._sequence[key], key, self._parent) - - def __len__(self): - return len(self._sequence) - - def __iter__(self): - self._index = 0 - return self - - def __next__(self): - try: - result = self[self._index] - self._index += 1 - return result - except IndexError: - raise StopIteration - - -class _IndirectIterator(Generic[_IteratorTo]): - - def __init__(self, indices, sequence): - self._indices = indices - self._index = 0 - self._sequence = sequence - - def __getitem__(self, key) -> _IteratorTo: - index = self._indices[key] - return self._sequence[index] - - def __len__(self): - return len(self._indices) - - def __iter__(self): - self._index = 0 - return self - - def __next__(self): - try: - result = self[self._index] - self._index += 1 - return result - except IndexError: - raise StopIteration - - -class _Operator: - - def __init__(self, operator, index, subgraph): - self.operator = operator - self.index = index - self.subgraph = subgraph - - @property - def opcode(self) -> tflite.OperatorCodeT: - return self.subgraph.model.operatorCodes[self.operator.opcodeIndex] - - @property - def inputs(self): - return _IndirectIterator(self.operator.inputs, self.subgraph.tensors) - - -_NP_DTYPES = { - tflite.TensorType.FLOAT16: np.dtype(" _Buffer: - return self.subgraph.model.buffers[self._tensor_t.buffer] - - @property - def data(self) -> bytes: - return self.buffer.data - - @property - def dtype(self) -> np.dtype: - return _NP_DTYPES[self._tensor_t.type] - - @property - def array(self) -> np.ndarray: - """Returns an array created from the Tensor's data, type, and shape. - - Note the bytes in the data buffer and the Tensor's type and shape may be - inconsistent, and thus the returned array invalid, if the data buffer has - been altered according to the compression schema, in which the data buffer - is an array of fixed-width, integer fields. - """ - return np.frombuffer(self.data, - dtype=self.dtype).reshape(self._tensor_t.shape) - - @property - def quantization(self) -> tflite.QuantizationParametersT | None: - return self._tensor_t.quantization - - -class _Buffer: - - def __init__(self, buffer_t: tflite.BufferT, index, model): - self._buffer_t = buffer_t - self.index = index - self.model = model - - @property - def data(self) -> bytes: - return bytes(self._buffer_t.data) - - @data.setter - def data(self, value: ByteString): - self._buffer_t.data = list(value) - - def extend(self, values: NDArray): - self._buffer_t.data.extend(values.tobytes()) - - -class _Subgraph: - - def __init__(self, subgraph_t: tflite.SubGraphT, index: int, model: _Model): - self._subgraph_t = subgraph_t - self.index = index - self.model = model - - @property - def operators(self) -> _Iterator[_Operator]: - return _Iterator(self._subgraph_t.operators, _Operator, parent=self) - - @property - def tensors(self) -> _Iterator[_Tensor]: - return _Iterator(self._subgraph_t.tensors, _Tensor, parent=self) - - -class _Model: - """A facade for manipulating tflite.Model. - """ - - def __init__(self, model_t: tflite.ModelT): - self._model_t = model_t - - def compile(self) -> bytearray: - """Returns a tflite.Model flatbuffer. - """ - size_hint = 4 * 2**10 - builder = flatbuffers.Builder(size_hint) - builder.Finish(self._model_t.Pack(builder)) - return builder.Output() - - def add_buffer(self) -> _Buffer: - """Adds a buffer to the model. - """ - buffer = tflite.BufferT() - buffer.data = [] - self._model_t.buffers.append(buffer) - index = len(self._model_t.buffers) - 1 - return _Buffer(buffer, index, self._model_t) - - def add_metadata(self, key, value): - """Adds a key-value pair, writing value to a newly created buffer. - """ - metadata = tflite.MetadataT() - metadata.name = key - buffer = self.add_buffer() - buffer.data = value - metadata.buffer = buffer.index - self._model_t.metadata.append(metadata) - - @property - def metadata(self) -> dict[str, _Buffer]: - """Returns the model's metadata as a dictionary to Buffer objects. - """ - result = {} - for m in self._model_t.metadata: - name = m.name.decode("utf-8") # type: ignore (fb library is wrong) - buffer = _Buffer(self._model_t.buffers[m.buffer], m.buffer, - self._model_t) - result[name] = buffer - - return result - - @property - def operatorCodes(self): - return self._model_t.operatorCodes - - @property - def subgraphs(self) -> _Iterator[_Subgraph]: - return _Iterator(self._model_t.subgraphs, _Subgraph, parent=self) - - @property - def buffers(self) -> _Iterator[_Buffer]: - return _Iterator(self._model_t.buffers, _Buffer, parent=self) - - -def read(buffer: ByteString): - """Reads a tflite.Model and returns a model facade. - """ - schema_model = tflite.ModelT.InitFromPackedBuf(buffer, 0) - return _Model(schema_model) diff --git a/tensorflow/lite/micro/compression/model_facade_test.py b/tensorflow/lite/micro/compression/model_facade_test.py deleted file mode 100644 index 87e71fa968b..00000000000 --- a/tensorflow/lite/micro/compression/model_facade_test.py +++ /dev/null @@ -1,144 +0,0 @@ -# Copyright 2024 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import numpy as np -import unittest -from tflite_micro.tensorflow.lite.python import schema_py_generated as tflite -from tflite_micro.tensorflow.lite.micro.compression import model_facade -from tflite_micro.tensorflow.lite.micro.compression import test_models - -TEST_MODEL = { - "operator_codes": { - 0: { - "builtin_code": tflite.BuiltinOperator.FULLY_CONNECTED, - }, - 1: { - "builtin_code": tflite.BuiltinOperator.ADD, - }, - }, - "metadata": { - 0: { - "name": "metadata0", - "buffer": 0 - }, - 1: { - "name": "metadata1", - "buffer": 0 - }, - }, - "subgraphs": { - 0: { - "operators": { - 0: { - "opcode_index": 1, # ADD - "inputs": ( - 1, - 2, - ), - "outputs": (3, ), - }, - 1: { - "opcode_index": 0, # FULLY_CONNECTED - "inputs": ( - 3, - 4, - 5, - ), - "outputs": (6, ), - }, - }, - "tensors": { - 0: { - "name": "tensor0", - "shape": (16, 1), - "type": tflite.TensorType.INT8, - "buffer": 1, - }, - 1: { - "name": "tensor1", - "shape": (8, 1), - "type": tflite.TensorType.INT16, - "buffer": 2, - }, - 2: { - "name": "tensor2", - "shape": (4, 1), - "type": tflite.TensorType.INT32, - "buffer": 3, - }, - 3: { - "name": "tensor3", - "shape": (2, 1), - "type": tflite.TensorType.INT64, - "buffer": 4, - }, - }, - }, - }, - "buffers": { - 0: None, - 1: np.array(range(16), dtype=np.dtype(" np.dtype: + """Convert TFLite dtype to numpy dtype.""" + type_map = { + tflite.TensorType.INT8: np.int8, + tflite.TensorType.INT16: np.int16, + tflite.TensorType.INT32: np.int32, + tflite.TensorType.INT64: np.int64, + tflite.TensorType.UINT8: np.uint8, + tflite.TensorType.UINT16: np.uint16, + tflite.TensorType.UINT32: np.uint32, + tflite.TensorType.FLOAT16: np.float16, + tflite.TensorType.FLOAT32: np.float32, + tflite.TensorType.FLOAT64: np.float64, + tflite.TensorType.BOOL: np.bool_, + } + return type_map.get(dtype, np.uint8) + + +class ProprietaryModelTest(unittest.TestCase): + """Integration tests using proprietary models.""" + + # Parsed from command line in main() + models_dir = None + + @classmethod + def setUpClass(cls): + if not cls.models_dir: + raise unittest.SkipTest( + "No models directory provided. " + "Usage: bazel test ... --test_arg=/path/to/models") + + cls.model_paths = sorted( + glob.glob(os.path.join(cls.models_dir, '*.tflite'))) + if not cls.model_paths: + raise unittest.SkipTest(f"No .tflite files found in {cls.models_dir}") + + def test_all_models(self): + """Run compression test on each discovered model.""" + for model_path in self.model_paths: + with self.subTest(model=os.path.basename(model_path)): + self._test_model_compression(model_path) + + def _test_model_compression(self, model_path): + """Test that a compressed model produces same outputs as original.""" + with open(model_path, 'rb') as f: + flatbuffer = f.read() + + # Load compression spec from sidecar file + specs = self._load_compression_spec(model_path) + + # Load tolerance config + rtol, atol = self._load_tolerance(model_path) + + # Compress the model + compressed_fb = compress.compress(flatbuffer, specs) + + # Create interpreters + original_interp = runtime.Interpreter.from_bytes(bytes(flatbuffer)) + compressed_interp = runtime.Interpreter.from_bytes(bytes(compressed_fb)) + + # Generate random inputs and compare outputs + np.random.seed(42) + model = model_editor.read(flatbuffer) + sg = model.subgraphs[0] + + for trial in range(5): + # Set inputs + for i, input_tensor in enumerate(sg.inputs): + test_input = self._generate_input(input_tensor) + original_interp.set_input(test_input, i) + compressed_interp.set_input(test_input, i) + + # Run inference + original_interp.invoke() + compressed_interp.invoke() + + # Compare outputs + for i in range(len(sg.outputs)): + expected = original_interp.get_output(i) + actual = compressed_interp.get_output(i) + self._compare_outputs(expected, actual, rtol, atol, + f"trial {trial}, output {i}") + + def _generate_input(self, tensor): + """Generate random input respecting tensor dtype.""" + shape = tensor.shape + dtype = _dtype_to_numpy(tensor.dtype) + + if np.issubdtype(dtype, np.floating): + return np.random.uniform(-1.0, 1.0, shape).astype(dtype) + elif np.issubdtype(dtype, np.integer): + info = np.iinfo(dtype) + return np.random.randint(info.min, info.max + 1, shape, dtype=dtype) + elif dtype == np.bool_: + return np.random.choice([False, True], shape) + return np.zeros(shape, dtype=dtype) + + def _load_compression_spec(self, model_path): + """Load compression spec from sidecar YAML file. + + Raises: + FileNotFoundError: If no spec file is found. + """ + spec_path = model_path.replace('.tflite', '.spec.yaml') + if os.path.exists(spec_path): + with open(spec_path) as f: + return spec.parse_yaml(f.read()) + + raise FileNotFoundError( + f"No compression spec file found for {model_path}. " + f"Expected: {spec_path}") + + def _load_tolerance(self, model_path): + """Load tolerance from sidecar config if present. + + Returns (0, 0) for exact match if no config file exists. + """ + config_path = model_path.replace('.tflite', '.config.json') + if os.path.exists(config_path): + with open(config_path) as f: + config = json.load(f) + return config.get('rtol', 0), config.get('atol', 0) + return 0, 0 + + def _compare_outputs(self, expected, actual, rtol, atol, context=""): + """Compare outputs with optional tolerance.""" + msg = f"Output mismatch ({context})" if context else "Output mismatch" + if rtol == 0 and atol == 0: + np.testing.assert_array_equal(expected, actual, err_msg=msg) + else: + np.testing.assert_allclose(expected, + actual, + rtol=rtol, + atol=atol, + err_msg=msg) + + +if __name__ == "__main__": + # Suppress TF C++ info/debug logs (0=DEBUG, 1=INFO, 2=WARNING, 3=ERROR) + os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" + # Disable oneDNN to avoid non-deterministic floating point results + os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0" + + # Parse models directory from args, then strip it so tf.test doesn't see it + for arg in sys.argv[1:]: + if not arg.startswith('-') and os.path.isdir(arg): + ProprietaryModelTest.models_dir = arg + sys.argv.remove(arg) + break + + unittest.main() diff --git a/tensorflow/lite/micro/compression/pruning.py b/tensorflow/lite/micro/compression/pruning.py new file mode 100644 index 00000000000..5c95e3e87e9 --- /dev/null +++ b/tensorflow/lite/micro/compression/pruning.py @@ -0,0 +1,59 @@ +# Copyright 2026 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Pruning compression plugin (stub). + +This module provides a placeholder for pruning (sparsity) compression. +The actual implementation is not yet available. + +Supported tensor types (when implemented): All TFLM tensor types +""" + +from tflite_micro.tensorflow.lite.micro.compression import compressor +from tflite_micro.tensorflow.lite.micro.compression import decode +from tflite_micro.tensorflow.lite.micro.compression import model_editor +from tflite_micro.tensorflow.lite.micro.compression import spec + + +class PruningCompressor(compressor.Compressor): + """Pruning compression plugin (stub). + + This stub exists to validate the plugin architecture. The actual pruning + compression algorithm for sparse tensors is not yet implemented. + """ + + @property + def decode_type(self) -> decode.DecodeType: + """Returns DecodeType.PRUNING.""" + return decode.DecodeType.PRUNING + + def compress( + self, + tensor: model_editor.Tensor, + method: spec.CompressionMethod, + ) -> compressor.CompressionResult: + """Compress a tensor using pruning (sparsity) encoding. + + Args: + tensor: The tensor to compress. + method: Must be a PruningCompression instance. + + Returns: + CompressionResult (not implemented). + + Raises: + CompressionError: Always, since this is a stub. + """ + raise compressor.CompressionError( + "Pruning compression not yet implemented. " + "This stub exists to validate the plugin architecture.") diff --git a/tensorflow/lite/micro/compression/spec.py b/tensorflow/lite/micro/compression/spec.py index 6f782e92d7a..5c0f81885bc 100644 --- a/tensorflow/lite/micro/compression/spec.py +++ b/tensorflow/lite/micro/compression/spec.py @@ -58,10 +58,32 @@ class Tensor: @dataclass class LookUpTableCompression(CompressionMethod): + """LUT compression using lookup tables. + Attributes: + index_bitwidth: Number of bits per index (1-7). + """ index_bitwidth: int +@dataclass +class HuffmanCompression(CompressionMethod): + """Huffman compression using Xtensa-format decode tables. + + Supported tensor types: INT8, INT16 only. + """ + pass + + +@dataclass +class PruningCompression(CompressionMethod): + """Pruning (sparsity) compression. + + Supported tensor types: All TFLM tensor types. + """ + pass + + class ParseError(Exception): "Raised when the spec string cannot be parsed." @@ -70,6 +92,18 @@ def __init__(self, message="error parsing spec", wrapped_exception=None): self.original_exception = wrapped_exception +def _parse_compression_method(comp: dict) -> CompressionMethod: + """Parse a single compression method from YAML dict.""" + if "lut" in comp: + return LookUpTableCompression(index_bitwidth=comp["lut"]["index_bitwidth"]) + elif "huffman" in comp: + return HuffmanCompression() + elif "pruning" in comp: + return PruningCompression() + else: + raise ParseError(f"Unknown compression method: {list(comp.keys())}") + + def parse_yaml(y: str) -> list[Tensor]: "Parses a compression spec in a YAML string into its Python representation." try: @@ -77,14 +111,19 @@ def parse_yaml(y: str) -> list[Tensor]: tensors = [] for item in config["tensors"]: - bitwidth = item["compression"][0]["lut"]["index_bitwidth"] - tensor = Tensor(subgraph=item["subgraph"], - tensor=item["tensor"], - compression=[ - LookUpTableCompression(index_bitwidth=bitwidth), - ]) + methods = [] + for comp in item["compression"]: + methods.append(_parse_compression_method(comp)) + + tensor = Tensor( + subgraph=item["subgraph"], + tensor=item["tensor"], + compression=methods, + ) tensors.append(tensor) + except ParseError: + raise except Exception as e: raise ParseError() from e diff --git a/tensorflow/lite/micro/compression/test_models.py b/tensorflow/lite/micro/compression/test_models.py deleted file mode 100644 index 80286d17359..00000000000 --- a/tensorflow/lite/micro/compression/test_models.py +++ /dev/null @@ -1,190 +0,0 @@ -# Copyright 2024 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -"""Tools for constructing flatbuffers for testing. - -This module provides tools for constructing .tflite flatbuffers from a Python -dictionary representation of a model, a prototype of which can be found in -EXAMPLE_MODEL. - -Example usage: - model_definition = {...} # use EXAMPLE_MODEL as prototype - flatbuffer: bytearray = test_models.build(model_definition) -""" - -# This module must remain low-level and independent from any helpers in this -# project which make constructing model and flatbuffers easier, because this -# module is used to define tests for those helpers. - -import flatbuffers -import numpy as np -from tflite_micro.tensorflow.lite.python import schema_py_generated as tflite - -EXAMPLE_MODEL = { - "operator_codes": { - 0: { - "builtin_code": tflite.BuiltinOperator.FULLY_CONNECTED, - }, - 1: { - "builtin_code": tflite.BuiltinOperator.ADD, - }, - }, - "metadata": { - 0: { - "name": "metadata0", - "buffer": 0 - }, - }, - "subgraphs": { - 0: { - "operators": { - 0: { - "opcode_index": 1, - "inputs": ( - 0, - 1, - ), - "outputs": (3, ), - }, - 1: { - "opcode_index": 0, - "inputs": ( - 3, - 2, - ), - "outputs": (4, ), - }, - }, - "tensors": { - 0: { - "shape": (16, 1), - "type": tflite.TensorType.INT8, - "buffer": 1, - }, - 1: { - "shape": (16, 1), - "type": tflite.TensorType.INT8, - "buffer": 1, - }, - 2: { - "shape": (16, 1), - "type": tflite.TensorType.INT8, - "buffer": 1, - }, - 3: { - "shape": (16, 1), - "type": tflite.TensorType.INT8, - "buffer": 1, - "quantization": { - "quantized_dimension": 0, - }, - }, - }, - }, - }, - "buffers": { - 0: None, - 1: np.array(range(16), dtype=np.dtype(" bytearray: - """Builds a .tflite flatbuffer from a model definition. - - Args: - model_definition: A dictionary representation of the model, a prototype of - which can be found in the EXAMPLE_MODEL attribute of this module. - - Returns: - A tflite flatbuffer. - """ - root = tflite.ModelT() - description = model_definition.get("description") - if description is not None: - root.description = description - - root.operatorCodes = [] - for id, operator_code in model_definition["operator_codes"].items(): - assert id == len(root.operatorCodes) - opcode_t = tflite.OperatorCodeT() - root.operatorCodes.append(opcode_t) - opcode_t.builtinCode = operator_code["builtin_code"] - - root.metadata = [] - if "metadata" in model_definition: - for _, metadata in model_definition["metadata"].items(): - metadata_t = tflite.MetadataT() - metadata_t.name = metadata["name"] - metadata_t.buffer = metadata["buffer"] - root.metadata.append(metadata_t) - - root.subgraphs = [] - for id, subgraph in model_definition["subgraphs"].items(): - assert id == len(root.subgraphs) - subgraph_t = tflite.SubGraphT() - root.subgraphs.append(subgraph_t) - - subgraph_t.operators = [] - for id, operator in subgraph["operators"].items(): - assert id == len(subgraph_t.operators) - operator_t = tflite.OperatorT() - operator_t.opcodeIndex = operator["opcode_index"] - operator_t.inputs = operator["inputs"] - operator_t.outputs = operator["outputs"] - subgraph_t.operators.append(operator_t) - - subgraph_t.tensors = [] - for id, tensor in subgraph["tensors"].items(): - assert id == len(subgraph_t.tensors) - tensor_t = tflite.TensorT() - tensor_t.name = tensor.get("name", None) - tensor_t.shape = tensor["shape"] - tensor_t.type = tensor["type"] - tensor_t.buffer = tensor["buffer"] - - if "quantization" in tensor: - tensor_t.quantization = tflite.QuantizationParametersT() - tensor_t.quantization.quantizedDimension = \ - tensor["quantization"].get("quantized_dimension", None) - tensor_t.quantization.scale = \ - tensor["quantization"].get("scale", None) - tensor_t.quantization.zeroPoint = \ - tensor["quantization"].get("zero_point", None) - - subgraph_t.tensors.append(tensor_t) - - root.buffers = [] - for id, data in model_definition["buffers"].items(): - assert id == len(root.buffers) - buffer_t = tflite.BufferT() - - if data is None: - buffer_t.data = [] - elif isinstance(data, np.ndarray): - array = data.astype(data.dtype.newbyteorder("<")) # ensure little-endian - buffer_t.data = list(array.tobytes()) - else: - raise TypeError(f"buffer_id {id} must be None or an np.ndarray") - - root.buffers.append(buffer_t) - - size_hint = 1 * 2**20 - builder = flatbuffers.Builder(size_hint) - builder.Finish(root.Pack(builder)) - flatbuffer = builder.Output() - return flatbuffer diff --git a/tensorflow/lite/micro/compression/test_models_test.py b/tensorflow/lite/micro/compression/test_models_test.py deleted file mode 100644 index d7e961c2dd9..00000000000 --- a/tensorflow/lite/micro/compression/test_models_test.py +++ /dev/null @@ -1,32 +0,0 @@ -# Copyright 2024 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest -from tflite_micro.tensorflow.lite.micro.compression import test_models -from tflite_micro.tensorflow.lite.python import schema_py_generated as tflite - - -class TestBuild(unittest.TestCase): - - def setUp(self): - self.flatbuffer = test_models.build(test_models.EXAMPLE_MODEL) - - def testNotDegenerate(self): - model = tflite.ModelT.InitFromPackedBuf(self.flatbuffer, 0) - self.assertEqual(model.operatorCodes[0].builtinCode, - tflite.BuiltinOperator.FULLY_CONNECTED) - - -if __name__ == "__main__": - unittest.main()