Skip to content

Commit 074b75f

Browse files
authored
feat(compression): add Compressor protocol (#3590)
Define the plugin interface for compression methods. Each compressor implements the Compressor protocol with a compress() method that returns encoded data and ancillary data. BUG=part of #3256
1 parent 7ca66d1 commit 074b75f

2 files changed

Lines changed: 90 additions & 0 deletions

File tree

tensorflow/lite/micro/compression/BUILD

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,16 @@ tflm_py_test(
261261
],
262262
)
263263

264+
tflm_py_library(
265+
name = "compressor",
266+
srcs = ["compressor.py"],
267+
deps = [
268+
":decode",
269+
":model_editor",
270+
":spec",
271+
],
272+
)
273+
264274
tflm_py_binary(
265275
name = "view",
266276
srcs = [
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
# Copyright 2026 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Compression plugin interface."""
15+
16+
from dataclasses import dataclass
17+
from typing import Protocol
18+
19+
from tflite_micro.tensorflow.lite.micro.compression import decode
20+
from tflite_micro.tensorflow.lite.micro.compression import model_editor
21+
from tflite_micro.tensorflow.lite.micro.compression import spec
22+
23+
24+
class CompressionError(Exception):
25+
"""Raised when compression fails for the reason documented in the message."""
26+
27+
def __init__(self, message, wrapped_exception=None):
28+
if wrapped_exception:
29+
super().__init__(f"{message}: {str(wrapped_exception)}")
30+
else:
31+
super().__init__(message)
32+
self.original_exception = wrapped_exception
33+
34+
35+
@dataclass
36+
class CompressionResult:
37+
"""Result of compressing a tensor.
38+
39+
Attributes:
40+
encoded_data: The compressed tensor data (e.g., packed indices for LUT).
41+
ancillary_data: The complete ancillary data tensor bytes (DCM + type-specific
42+
data). This is the full buffer contents for the ancillary
43+
tensor.
44+
"""
45+
encoded_data: bytes
46+
ancillary_data: bytes
47+
48+
49+
class Compressor(Protocol):
50+
"""Protocol that compression plugins must implement.
51+
52+
Each compression method (LUT, Huffman, Pruning) provides a class implementing
53+
this protocol. The compress() function uses duck typing to call the plugin.
54+
"""
55+
56+
@property
57+
def decode_type(self) -> decode.DecodeType:
58+
"""The DecodeType constant for this compression method."""
59+
...
60+
61+
def compress(
62+
self,
63+
tensor: model_editor.Tensor,
64+
method: spec.CompressionMethod,
65+
) -> CompressionResult:
66+
"""Compress a tensor according to the specified method.
67+
68+
Args:
69+
tensor: The tensor to compress. Must have data (tensor.array is not None)
70+
and quantization parameters for axis inference.
71+
method: The compression method spec (e.g., LookUpTableCompression).
72+
73+
Returns:
74+
CompressionResult with encoded tensor data and ancillary data bytes.
75+
76+
Raises:
77+
CompressionError: If compression fails (e.g., too many unique values
78+
for specified bitwidth, missing quantization, etc.).
79+
"""
80+
...

0 commit comments

Comments
 (0)