|
| 1 | +# Copyright 2026 The TensorFlow Authors. All Rights Reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | +"""DECODE compression module.""" |
| 15 | + |
| 16 | +# Implements the DECODE operator compression scheme described in the |
| 17 | +# "TFLM DECODE Operator Design" document, revised May 20, 2025. |
| 18 | +# |
| 19 | +# The DECODE operator transforms an encoded tensor, alongside a paired |
| 20 | +# ancillary data tensor, into a tensor ready for use as input to any |
| 21 | +# operator. For example, an encoded tensor might contain compressed |
| 22 | +# data, while the paired ancillary data tensor holds the information |
| 23 | +# necessary for decompression. The DECODE operator's output is a fully |
| 24 | +# decompressed tensor. |
| 25 | +# |
| 26 | +# DECODE operators are inserted into the TfLite model subgraph |
| 27 | +# immediately before each operation that uses a decodable tensor as |
| 28 | +# input. |
| 29 | +# |
| 30 | +# Ancillary Data Tensor |
| 31 | +# |
| 32 | +# The ancillary data tensor contains the information necessary for |
| 33 | +# decoding. It begins with a 16-byte DECODE Common Metadata (DCM) |
| 34 | +# header, followed by decode-type-specific ancillary data. |
| 35 | +# |
| 36 | +# DECODE Common Metadata (DCM) |
| 37 | +# |
| 38 | +# Byte 0: Decode type |
| 39 | +# 0-127: TFLM-supported decode operations (see below) |
| 40 | +# 128-255: Custom operations requiring application-registered |
| 41 | +# handlers |
| 42 | +# |
| 43 | +# Supported decode types: |
| 44 | +# |
| 45 | +# 0: LUT decompression |
| 46 | +# All TFLM tensor types supported in reference and optimized |
| 47 | +# code. |
| 48 | +# |
| 49 | +# 1: Huffman decompression using Xtensa format decode tables |
| 50 | +# INT8 and INT16 tensor types only, in reference and optimized |
| 51 | +# code. |
| 52 | +# |
| 53 | +# 2: Pruning decompression |
| 54 | +# All TFLM tensor types supported in reference and optimized |
| 55 | +# code. |
| 56 | +# |
| 57 | +# 3-127: Reserved |
| 58 | +# |
| 59 | +# 128-255: Custom decode types |
| 60 | +# Requires user-supplied encoding module and decoding ancillary |
| 61 | +# data. |
| 62 | +# |
| 63 | +# Byte 1: DCM version (currently 1) |
| 64 | +# |
| 65 | +# Bytes 2-3: Reserved |
| 66 | +# |
| 67 | +# Bytes 4-15: User-defined |
| 68 | +# Used by TFLM decode types to avoid requiring additional alignment |
| 69 | +# of metadata or ancillary data. |
| 70 | +# |
| 71 | +# The 16-byte DCM size ensures that subsequent metadata and ancillary |
| 72 | +# data are 128-bit aligned, which is required for some optimized |
| 73 | +# decoding operations such as Xtensa LUT decompression. |
| 74 | +# |
| 75 | +# For TFLM decode types, ancillary data starts immediately after the |
| 76 | +# DCM. For custom decode types, the location is determined by |
| 77 | +# user-defined metadata. |
| 78 | + |
| 79 | +from dataclasses import dataclass |
| 80 | +from typing import Protocol |
| 81 | + |
| 82 | + |
| 83 | +class DecodeType: |
| 84 | + """Decode operation type (0-255). |
| 85 | +
|
| 86 | + Use predefined constants for built-in types or DecodeType.custom() |
| 87 | + for custom types: |
| 88 | + DecodeType.LUT # 0 |
| 89 | + DecodeType.HUFFMAN # 1 |
| 90 | + DecodeType.PRUNING # 2 |
| 91 | + DecodeType.custom(200) # Custom type 128-255 |
| 92 | + """ |
| 93 | + |
| 94 | + # Built-in decode types (class variables set after class definition) |
| 95 | + LUT: 'DecodeType' |
| 96 | + HUFFMAN: 'DecodeType' |
| 97 | + PRUNING: 'DecodeType' |
| 98 | + |
| 99 | + def __init__(self, code: int, name: str = None): |
| 100 | + """Initialize DecodeType. |
| 101 | +
|
| 102 | + Args: |
| 103 | + code: Integer code 0-255 |
| 104 | + name: Optional name for the type. If not provided: |
| 105 | + - Codes 0-127: Named "TYPE_{code}" |
| 106 | + - Codes 128-255: Named "CUSTOM_{code}" |
| 107 | + """ |
| 108 | + if not 0 <= code <= 255: |
| 109 | + raise ValueError(f"Decode type must be 0-255, got {code}") |
| 110 | + self.code = code |
| 111 | + |
| 112 | + # Auto-generate name if not provided |
| 113 | + if name is None: |
| 114 | + self.name = f"CUSTOM_{code}" if code >= 128 else f"TYPE_{code}" |
| 115 | + else: |
| 116 | + self.name = name |
| 117 | + |
| 118 | + self._is_custom = code >= 128 |
| 119 | + |
| 120 | + @property |
| 121 | + def is_custom(self) -> bool: |
| 122 | + """True if this is a custom decode type (128-255).""" |
| 123 | + return self._is_custom |
| 124 | + |
| 125 | + @classmethod |
| 126 | + def custom(cls, code: int) -> 'DecodeType': |
| 127 | + """Create custom decode type (128-255). |
| 128 | +
|
| 129 | + Args: |
| 130 | + code: Integer code 128-255 |
| 131 | +
|
| 132 | + Returns: |
| 133 | + DecodeType with name CUSTOM_{code} |
| 134 | + """ |
| 135 | + if not 128 <= code <= 255: |
| 136 | + raise ValueError(f"Custom decode type must be 128-255, got {code}") |
| 137 | + return cls(code) |
| 138 | + |
| 139 | + def __int__(self): |
| 140 | + """Convert to integer for serialization.""" |
| 141 | + return self.code |
| 142 | + |
| 143 | + def __eq__(self, other): |
| 144 | + if isinstance(other, DecodeType): |
| 145 | + return self.code == other.code |
| 146 | + return self.code == other |
| 147 | + |
| 148 | + def __repr__(self): |
| 149 | + return f"DecodeType.{self.name}({self.code})" |
| 150 | + |
| 151 | + |
| 152 | +# Define built-in decode type constants |
| 153 | +DecodeType.LUT = DecodeType(0, "LUT") |
| 154 | +DecodeType.HUFFMAN = DecodeType(1, "HUFFMAN") |
| 155 | +DecodeType.PRUNING = DecodeType(2, "PRUNING") |
| 156 | + |
| 157 | + |
| 158 | +@dataclass |
| 159 | +class DecodeCommonMetadata: |
| 160 | + """16-byte DECODE Common Metadata (DCM) header. |
| 161 | +
|
| 162 | + Attributes: |
| 163 | + decode_type: Decode operation type. Use DecodeType constants or |
| 164 | + DecodeType.custom(code) for custom types. |
| 165 | + version: DCM version (currently 1). |
| 166 | + user_data: 12 bytes of user-defined data (bytes 4-15 of DCM). Used by TFLM |
| 167 | + decode types to avoid requiring additional alignment of metadata |
| 168 | + or ancillary data. |
| 169 | + """ |
| 170 | + decode_type: DecodeType |
| 171 | + version: int = 1 |
| 172 | + user_data: bytes = b'\x00' * 12 |
| 173 | + |
| 174 | + def to_bytes(self) -> bytes: |
| 175 | + """Serialize DCM to 16-byte sequence.""" |
| 176 | + decode_code = int(self.decode_type) |
| 177 | + if len(self.user_data) < 12: |
| 178 | + # Pad with zeros if user_data is too short |
| 179 | + user_data = self.user_data + b'\x00' * (12 - len(self.user_data)) |
| 180 | + else: |
| 181 | + user_data = self.user_data[:12] |
| 182 | + |
| 183 | + result = bytearray(16) |
| 184 | + result[0] = decode_code |
| 185 | + result[1] = self.version |
| 186 | + # bytes 2-3 remain zero (reserved) |
| 187 | + result[4:16] = user_data |
| 188 | + return bytes(result) |
| 189 | + |
| 190 | + |
| 191 | +class AncillaryDataSerializer(Protocol): |
| 192 | + """Protocol for objects that can serialize ancillary data.""" |
| 193 | + |
| 194 | + def to_bytes(self) -> bytes: |
| 195 | + ... |
| 196 | + |
| 197 | + |
| 198 | +@dataclass |
| 199 | +class AncillaryDataTensor: |
| 200 | + """Complete Ancillary Data Tensor (ADT): DCM + decode-type-specific data. |
| 201 | +
|
| 202 | + The ADT is stored as a buffer in the TFLite model. It begins with a 16-byte |
| 203 | + DCM header, followed by decode-type-specific ancillary data. |
| 204 | +
|
| 205 | + Attributes: |
| 206 | + dcm: The DECODE Common Metadata header. |
| 207 | + ancillary_data: The decode-type-specific ancillary data, either as raw bytes |
| 208 | + or as an object implementing the AncillaryDataSerializer |
| 209 | + protocol. May be None if only the DCM is needed. |
| 210 | + """ |
| 211 | + dcm: DecodeCommonMetadata |
| 212 | + ancillary_data: AncillaryDataSerializer | bytes | None = None |
| 213 | + |
| 214 | + def with_ancillary_data( |
| 215 | + self, data: AncillaryDataSerializer | bytes) -> 'AncillaryDataTensor': |
| 216 | + """Create new ADT with ancillary data added. |
| 217 | +
|
| 218 | + Args: |
| 219 | + data: Ancillary data to add, either as raw bytes or as an object |
| 220 | + implementing AncillaryDataSerializer. |
| 221 | +
|
| 222 | + Returns: |
| 223 | + New AncillaryDataTensor with the specified ancillary data. |
| 224 | + """ |
| 225 | + return AncillaryDataTensor(self.dcm, data) |
| 226 | + |
| 227 | + def to_bytes(self) -> bytes: |
| 228 | + """Serialize entire ADT to bytes. |
| 229 | +
|
| 230 | + Returns: |
| 231 | + Byte sequence containing DCM followed by ancillary data (if present). |
| 232 | + """ |
| 233 | + dcm_bytes = self.dcm.to_bytes() |
| 234 | + if self.ancillary_data is None: |
| 235 | + return dcm_bytes |
| 236 | + if isinstance(self.ancillary_data, bytes): |
| 237 | + return dcm_bytes + self.ancillary_data |
| 238 | + return dcm_bytes + self.ancillary_data.to_bytes() |
0 commit comments