Skip to content

Commit e0b2c28

Browse files
authored
feat(compression): add DECODE operator types and metadata (#3589)
Add decode module with DecodeType constants and DecodeCommonMetadata, per the TFLM DECODE Operator Design document. BUG=part of #3256
1 parent eea46d3 commit e0b2c28

3 files changed

Lines changed: 415 additions & 0 deletions

File tree

tensorflow/lite/micro/compression/BUILD

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,20 @@ tflm_py_test(
247247
],
248248
)
249249

250+
tflm_py_library(
251+
name = "decode",
252+
srcs = ["decode.py"],
253+
)
254+
255+
tflm_py_test(
256+
name = "decode_test",
257+
size = "small",
258+
srcs = ["decode_test.py"],
259+
deps = [
260+
":decode",
261+
],
262+
)
263+
250264
tflm_py_binary(
251265
name = "view",
252266
srcs = [
Lines changed: 238 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,238 @@
1+
# Copyright 2026 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""DECODE compression module."""
15+
16+
# Implements the DECODE operator compression scheme described in the
17+
# "TFLM DECODE Operator Design" document, revised May 20, 2025.
18+
#
19+
# The DECODE operator transforms an encoded tensor, alongside a paired
20+
# ancillary data tensor, into a tensor ready for use as input to any
21+
# operator. For example, an encoded tensor might contain compressed
22+
# data, while the paired ancillary data tensor holds the information
23+
# necessary for decompression. The DECODE operator's output is a fully
24+
# decompressed tensor.
25+
#
26+
# DECODE operators are inserted into the TfLite model subgraph
27+
# immediately before each operation that uses a decodable tensor as
28+
# input.
29+
#
30+
# Ancillary Data Tensor
31+
#
32+
# The ancillary data tensor contains the information necessary for
33+
# decoding. It begins with a 16-byte DECODE Common Metadata (DCM)
34+
# header, followed by decode-type-specific ancillary data.
35+
#
36+
# DECODE Common Metadata (DCM)
37+
#
38+
# Byte 0: Decode type
39+
# 0-127: TFLM-supported decode operations (see below)
40+
# 128-255: Custom operations requiring application-registered
41+
# handlers
42+
#
43+
# Supported decode types:
44+
#
45+
# 0: LUT decompression
46+
# All TFLM tensor types supported in reference and optimized
47+
# code.
48+
#
49+
# 1: Huffman decompression using Xtensa format decode tables
50+
# INT8 and INT16 tensor types only, in reference and optimized
51+
# code.
52+
#
53+
# 2: Pruning decompression
54+
# All TFLM tensor types supported in reference and optimized
55+
# code.
56+
#
57+
# 3-127: Reserved
58+
#
59+
# 128-255: Custom decode types
60+
# Requires user-supplied encoding module and decoding ancillary
61+
# data.
62+
#
63+
# Byte 1: DCM version (currently 1)
64+
#
65+
# Bytes 2-3: Reserved
66+
#
67+
# Bytes 4-15: User-defined
68+
# Used by TFLM decode types to avoid requiring additional alignment
69+
# of metadata or ancillary data.
70+
#
71+
# The 16-byte DCM size ensures that subsequent metadata and ancillary
72+
# data are 128-bit aligned, which is required for some optimized
73+
# decoding operations such as Xtensa LUT decompression.
74+
#
75+
# For TFLM decode types, ancillary data starts immediately after the
76+
# DCM. For custom decode types, the location is determined by
77+
# user-defined metadata.
78+
79+
from dataclasses import dataclass
80+
from typing import Protocol
81+
82+
83+
class DecodeType:
84+
"""Decode operation type (0-255).
85+
86+
Use predefined constants for built-in types or DecodeType.custom()
87+
for custom types:
88+
DecodeType.LUT # 0
89+
DecodeType.HUFFMAN # 1
90+
DecodeType.PRUNING # 2
91+
DecodeType.custom(200) # Custom type 128-255
92+
"""
93+
94+
# Built-in decode types (class variables set after class definition)
95+
LUT: 'DecodeType'
96+
HUFFMAN: 'DecodeType'
97+
PRUNING: 'DecodeType'
98+
99+
def __init__(self, code: int, name: str = None):
100+
"""Initialize DecodeType.
101+
102+
Args:
103+
code: Integer code 0-255
104+
name: Optional name for the type. If not provided:
105+
- Codes 0-127: Named "TYPE_{code}"
106+
- Codes 128-255: Named "CUSTOM_{code}"
107+
"""
108+
if not 0 <= code <= 255:
109+
raise ValueError(f"Decode type must be 0-255, got {code}")
110+
self.code = code
111+
112+
# Auto-generate name if not provided
113+
if name is None:
114+
self.name = f"CUSTOM_{code}" if code >= 128 else f"TYPE_{code}"
115+
else:
116+
self.name = name
117+
118+
self._is_custom = code >= 128
119+
120+
@property
121+
def is_custom(self) -> bool:
122+
"""True if this is a custom decode type (128-255)."""
123+
return self._is_custom
124+
125+
@classmethod
126+
def custom(cls, code: int) -> 'DecodeType':
127+
"""Create custom decode type (128-255).
128+
129+
Args:
130+
code: Integer code 128-255
131+
132+
Returns:
133+
DecodeType with name CUSTOM_{code}
134+
"""
135+
if not 128 <= code <= 255:
136+
raise ValueError(f"Custom decode type must be 128-255, got {code}")
137+
return cls(code)
138+
139+
def __int__(self):
140+
"""Convert to integer for serialization."""
141+
return self.code
142+
143+
def __eq__(self, other):
144+
if isinstance(other, DecodeType):
145+
return self.code == other.code
146+
return self.code == other
147+
148+
def __repr__(self):
149+
return f"DecodeType.{self.name}({self.code})"
150+
151+
152+
# Define built-in decode type constants
153+
DecodeType.LUT = DecodeType(0, "LUT")
154+
DecodeType.HUFFMAN = DecodeType(1, "HUFFMAN")
155+
DecodeType.PRUNING = DecodeType(2, "PRUNING")
156+
157+
158+
@dataclass
159+
class DecodeCommonMetadata:
160+
"""16-byte DECODE Common Metadata (DCM) header.
161+
162+
Attributes:
163+
decode_type: Decode operation type. Use DecodeType constants or
164+
DecodeType.custom(code) for custom types.
165+
version: DCM version (currently 1).
166+
user_data: 12 bytes of user-defined data (bytes 4-15 of DCM). Used by TFLM
167+
decode types to avoid requiring additional alignment of metadata
168+
or ancillary data.
169+
"""
170+
decode_type: DecodeType
171+
version: int = 1
172+
user_data: bytes = b'\x00' * 12
173+
174+
def to_bytes(self) -> bytes:
175+
"""Serialize DCM to 16-byte sequence."""
176+
decode_code = int(self.decode_type)
177+
if len(self.user_data) < 12:
178+
# Pad with zeros if user_data is too short
179+
user_data = self.user_data + b'\x00' * (12 - len(self.user_data))
180+
else:
181+
user_data = self.user_data[:12]
182+
183+
result = bytearray(16)
184+
result[0] = decode_code
185+
result[1] = self.version
186+
# bytes 2-3 remain zero (reserved)
187+
result[4:16] = user_data
188+
return bytes(result)
189+
190+
191+
class AncillaryDataSerializer(Protocol):
192+
"""Protocol for objects that can serialize ancillary data."""
193+
194+
def to_bytes(self) -> bytes:
195+
...
196+
197+
198+
@dataclass
199+
class AncillaryDataTensor:
200+
"""Complete Ancillary Data Tensor (ADT): DCM + decode-type-specific data.
201+
202+
The ADT is stored as a buffer in the TFLite model. It begins with a 16-byte
203+
DCM header, followed by decode-type-specific ancillary data.
204+
205+
Attributes:
206+
dcm: The DECODE Common Metadata header.
207+
ancillary_data: The decode-type-specific ancillary data, either as raw bytes
208+
or as an object implementing the AncillaryDataSerializer
209+
protocol. May be None if only the DCM is needed.
210+
"""
211+
dcm: DecodeCommonMetadata
212+
ancillary_data: AncillaryDataSerializer | bytes | None = None
213+
214+
def with_ancillary_data(
215+
self, data: AncillaryDataSerializer | bytes) -> 'AncillaryDataTensor':
216+
"""Create new ADT with ancillary data added.
217+
218+
Args:
219+
data: Ancillary data to add, either as raw bytes or as an object
220+
implementing AncillaryDataSerializer.
221+
222+
Returns:
223+
New AncillaryDataTensor with the specified ancillary data.
224+
"""
225+
return AncillaryDataTensor(self.dcm, data)
226+
227+
def to_bytes(self) -> bytes:
228+
"""Serialize entire ADT to bytes.
229+
230+
Returns:
231+
Byte sequence containing DCM followed by ancillary data (if present).
232+
"""
233+
dcm_bytes = self.dcm.to_bytes()
234+
if self.ancillary_data is None:
235+
return dcm_bytes
236+
if isinstance(self.ancillary_data, bytes):
237+
return dcm_bytes + self.ancillary_data
238+
return dcm_bytes + self.ancillary_data.to_bytes()

0 commit comments

Comments
 (0)