-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Expand file tree
/
Copy pathruntime.py
More file actions
328 lines (265 loc) · 11.5 KB
/
Copy pathruntime.py
File metadata and controls
328 lines (265 loc) · 11.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python package for TFLM Python Interpreter"""
import enum
import os
from tflite_micro.python.tflite_micro import _runtime
from tflite_micro.tensorflow.lite.python import schema_py_generated as schema_fb
def convert_bytearray_to_object(model_bytearray):
  """Unpacks a serialized tflite model into a `schema_fb.ModelT` object.

  Args:
    model_bytearray: Buffer holding the serialized model. The root `Model`
      table is read starting at offset 0.

  Returns:
    The `schema_fb.ModelT` built by `ModelT.InitFromObj` from the parsed root
    `Model`.
  """
  return schema_fb.ModelT.InitFromObj(
      schema_fb.Model.GetRootAsModel(model_bytearray, 0))
def get_builtin_code_from_operator_code(opcode):
  """Returns the effective builtin code stored in an operator code.

  Two object shapes are accepted: one exposing the callables `BuiltinCode()`
  and `DeprecatedBuiltinCode()`, and one exposing the plain attributes
  `builtinCode` and `deprecatedBuiltinCode`.

  Args:
    opcode: Operator code object in either of the shapes described above.

  Returns:
    The larger of the builtin code and the deprecated builtin code.
  """
  accessor = getattr(opcode, 'BuiltinCode', None)
  if callable(accessor):
    # Accessor-style object: both codes are read through method calls.
    current, deprecated = accessor(), opcode.DeprecatedBuiltinCode()
  else:
    # Attribute-style object: both codes are plain fields.
    current, deprecated = opcode.builtinCode, opcode.deprecatedBuiltinCode
  return max(current, deprecated)
def count_resource_variables(model):
  """Counts the distinct resource variables referenced by a model.

  Every operator of every subgraph is inspected. For each operator whose
  builtin code is `VAR_HANDLE`, the `sharedName` of its builtin options is
  collected, and the number of distinct names is returned.

  Args:
    model: Either a `schema_fb.ModelT`, or any other value, which is first
      passed through `convert_bytearray_to_object`.

  Returns:
    The number of unique shared names found on `VAR_HANDLE` operators.
  """
  model_t = model
  if not isinstance(model_t, schema_fb.ModelT):
    model_t = convert_bytearray_to_object(model_t)

  shared_names = set()
  for graph in model_t.subgraphs:
    # A subgraph without operators stores None rather than an empty list.
    for operator in graph.operators if graph.operators is not None else ():
      opcode = model_t.operatorCodes[operator.opcodeIndex]
      if (get_builtin_code_from_operator_code(opcode)
          != schema_fb.BuiltinOperator.VAR_HANDLE):
        continue
      shared_names.add(operator.builtinOptions.sharedName)
  return len(shared_names)
class InterpreterConfig(enum.Enum):
  """Selects one of two mutually exclusive ways to use the TFLM interpreter.

  kAllocationRecording (the default used by `Interpreter`): all memory usage
  by the interpreter is recorded on inference. The GetTensor() api is disabled
  by the interpreter in this mode, since this configuration does not guarantee
  that valid data for all tensors is available post inference.

  kPreserveAllTensors: the memory usage by the interpreter is not recorded on
  inference.
  NOTE(review): the previous wording stated that GetTensor() is disabled in
  this mode too, with the same justification as above. That reads like a
  copy-paste slip; presumably this is the mode in which GetTensor() is
  usable. Confirm against the C++ wrapper.

  Usage:
    default_interpreter = Interpreter(...,
        intrepreter_config=InterpreterConfig.kAllocationRecording)
    preserve_interpreter = Interpreter(...,
        intrepreter_config=InterpreterConfig.kPreserveAllTensors)
  """

  # Record arena usage during inference.
  kAllocationRecording = 0
  # Do not record arena usage during inference.
  kPreserveAllTensors = 1
# TODO(b/297118768): Once the Korko Docker container for ubuntu x86 has
# immutabledict added to it, this should be turned into an immutabledict.
#
# Maps every public InterpreterConfig member to the member of
# _runtime.PythonInterpreterConfig that carries the same name.
_ENUM_TRANSLATOR = {
    config: getattr(_runtime.PythonInterpreterConfig, config.name)
    for config in InterpreterConfig
}
class Interpreter(object):
  """Python-side handle on a TFLM interpreter.

  Every public method forwards to the `_runtime.InterpreterWrapper` instance
  created in `__init__`. `from_file()` and `from_bytes()` are the convenience
  constructors that fill in default arguments.
  """

  def __init__(
      self,
      model_data,
      custom_op_registerers,
      arena_size,
      intrepreter_config=InterpreterConfig.kAllocationRecording,
      alt_decompression_memory_size=0,
  ):
    """Creates the underlying interpreter wrapper for a model.

    Args:
      model_data: Model in byte array format. Must not be None.
      custom_op_registerers: List of strings, each of which is the name of a
        custom OP registerer.
      arena_size: Tensor arena size in bytes. If None, the arena size defaults
        to 10 times `len(model_data)`.
      intrepreter_config: `InterpreterConfig` member selecting the interpreter
        mode. (The parameter name is misspelled; it is kept as is so existing
        keyword callers keep working.)
      alt_decompression_memory_size: Size in bytes of alternate decompression
        memory, forwarded unchanged to the wrapper.

    Raises:
      ValueError: If `model_data` is None, or `custom_op_registerers` is not a
        list of strings.
      KeyError: If `intrepreter_config` is not a known `InterpreterConfig`.
    """
    if model_data is None:
      raise ValueError("Model must not be None")

    if not isinstance(custom_op_registerers, list) or not all(
        isinstance(s, str) for s in custom_op_registerers):
      raise ValueError("Custom ops registerers must be a list of strings")

    # This is a heuristic to ensure that the arena is sufficiently sized.
    if arena_size is None:
      arena_size = len(model_data) * 10

    # Some models make use of resource variables ops, get the count here
    num_resource_variables = count_resource_variables(model_data)
    print("Number of resource variables the model uses = ",
          num_resource_variables)

    self._interpreter = _runtime.InterpreterWrapper(
        model_data,
        custom_op_registerers,
        arena_size,
        num_resource_variables,
        _ENUM_TRANSLATOR[intrepreter_config],
        alt_decompression_memory_size,
    )

  @staticmethod
  def _check_index(index):
    """Raises ValueError when `index` is None or negative."""
    if index is None or index < 0:
      raise ValueError("Index must be a non-negative integer")

  @classmethod
  def from_file(
      cls,
      model_path,
      custom_op_registerers=None,
      arena_size=None,
      intrepreter_config=InterpreterConfig.kAllocationRecording,
      alt_decompression_memory_size=0,
  ):
    """Instantiates a TFLM interpreter from a model .tflite filepath.

    Args:
      model_path: Filepath to the .tflite model
      custom_op_registerers: List of strings, each of which is the name of a
        custom OP registerer. None (the default) means no registerers.
      arena_size: Tensor arena size in bytes. If unused, tensor arena size will
        default to 10 times the model size.
      intrepreter_config: `InterpreterConfig` member selecting the interpreter
        mode. Defaults to `InterpreterConfig.kAllocationRecording`.
      alt_decompression_memory_size: Size in bytes of alternate decompression
        memory. If non-zero, DECODE operators will use this memory instead of
        the main arena for decompressed tensor outputs.

    Returns:
      An Interpreter instance

    Raises:
      ValueError: If `model_path` is None or does not name an existing file.
    """
    if model_path is None or not os.path.isfile(model_path):
      raise ValueError("Invalid model file path")

    with open(model_path, "rb") as f:
      model_data = f.read()

    # A fresh list per call avoids sharing one default list between calls.
    if custom_op_registerers is None:
      custom_op_registerers = []

    return cls(
        model_data,
        custom_op_registerers,
        arena_size,
        intrepreter_config,
        alt_decompression_memory_size,
    )

  @classmethod
  def from_bytes(
      cls,
      model_data,
      custom_op_registerers=None,
      arena_size=None,
      intrepreter_config=InterpreterConfig.kAllocationRecording,
      alt_decompression_memory_size=0,
  ):
    """Instantiates a TFLM interpreter from a model in byte array.

    Args:
      model_data: Model in byte array format
      custom_op_registerers: List of strings, each of which is the name of a
        custom OP registerer. None (the default) means no registerers.
      arena_size: Tensor arena size in bytes. If unused, tensor arena size will
        default to 10 times the model size.
      intrepreter_config: `InterpreterConfig` member selecting the interpreter
        mode. Defaults to `InterpreterConfig.kAllocationRecording`.
      alt_decompression_memory_size: Size in bytes of alternate decompression
        memory. If non-zero, DECODE operators will use this memory instead of
        the main arena for decompressed tensor outputs.

    Returns:
      An Interpreter instance
    """
    # A fresh list per call avoids sharing one default list between calls.
    if custom_op_registerers is None:
      custom_op_registerers = []

    return cls(
        model_data,
        custom_op_registerers,
        arena_size,
        intrepreter_config,
        alt_decompression_memory_size,
    )

  def print_allocations(self):
    """Invoke the RecordingMicroAllocator to print the arena usage.

    This should be called after `invoke()`.

    Returns:
      This method does not return anything, but it dumps the arena
      usage to stderr.
    """
    self._interpreter.PrintAllocations()

  def invoke(self):
    """Invoke the TFLM interpreter to run an inference.

    This should be called after `set_input()`.

    Returns:
      Status code of the C++ invoke function. A RuntimeError will be raised as
      well upon any error.
    """
    return self._interpreter.Invoke()

  def reset(self):
    """Reset the model state to what it was when the interpreter was created.

    That is, the state after Init and Prepare are called for the very first
    time. This should be called after invoking a stateful model like LSTM.

    Returns:
      Status code of the C++ reset function. A RuntimeError will be raised as
      well upon any error.
    """
    return self._interpreter.Reset()

  def set_input(self, input_data, index):
    """Set input data into input tensor.

    This should be called before `invoke()`.

    Args:
      input_data: Input data in numpy array format. The numpy array format is
        chosen to be consistent with TFLite interpreter.
      index: An integer between 0 and the number of input tensors (exclusive)
        consistent with the order defined in the list of inputs in the .tflite
        model

    Raises:
      ValueError: If `input_data` is None, or `index` is None or negative.
    """
    if input_data is None:
      raise ValueError("Input data must not be None")
    self._check_index(index)

    self._interpreter.SetInputTensor(input_data, index)

  def get_output(self, index):
    """Get data from output tensor.

    The output data correspond to the most recent `invoke()`.

    Args:
      index: An integer between 0 and the number of output tensors (exclusive)
        consistent with the order defined in the list of outputs in the .tflite
        model

    Returns:
      Output data in numpy array format. The numpy array format is chosen to
      be consistent with TFLite interpreter.

    Raises:
      ValueError: If `index` is None or negative.
    """
    self._check_index(index)

    return self._interpreter.GetOutputTensor(index)

  def GetTensor(self, tensor_index, subgraph_index):
    """Fetches a tensor by its index within a given subgraph.

    The arguments are forwarded unchanged to the wrapper's `GetTensor`. Per the
    `InterpreterConfig` documentation, the GetTensor() api is disabled when the
    interpreter was created with `InterpreterConfig.kAllocationRecording`.

    Args:
      tensor_index: Index of the tensor to fetch.
      subgraph_index: Index of the subgraph that owns the tensor.

    Returns:
      Whatever the wrapper's `GetTensor` returns.
      NOTE(review): presumably tensor data in the same format as
      `get_output()`; confirm against the C++ wrapper.
    """
    return self._interpreter.GetTensor(tensor_index, subgraph_index)

  def get_input_details(self, index):
    """Get input tensor information

    Args:
      index (int): An integer between 0 and the number of input tensors
        (exclusive) consistent with the order defined in the list of inputs
        in the .tflite model

    Returns:
      A dictionary with details about the input tensor. It contains the
      following fields that describe the tensor:

      + `shape`: The shape of the tensor.
      + `dtype`: The numpy data type (such as `np.int32` or `np.uint8`).
      + `quantization_parameters`: A dictionary of parameters used to quantize
        the tensor:
        ~ `scales`: List of scales (one if per-tensor quantization).
        ~ `zero_points`: List of zero_points (one if per-tensor quantization).
        ~ `quantized_dimension`: Specifies the dimension of per-axis
        quantization, in the case of multiple scales/zero_points.

    Raises:
      ValueError: If `index` is None or negative.
    """
    self._check_index(index)

    return self._interpreter.GetInputTensorDetails(index)

  def get_output_details(self, index):
    """Get output tensor information

    Args:
      index (int): An integer between 0 and the number of output tensors
        (exclusive) consistent with the order defined in the list of outputs
        in the .tflite model

    Returns:
      A dictionary with details about the output tensor. It contains the
      following fields that describe the tensor:

      + `shape`: The shape of the tensor.
      + `dtype`: The numpy data type (such as `np.int32` or `np.uint8`).
      + `quantization_parameters`: A dictionary of parameters used to quantize
        the tensor:
        ~ `scales`: List of scales (one if per-tensor quantization).
        ~ `zero_points`: List of zero_points (one if per-tensor quantization).
        ~ `quantized_dimension`: Specifies the dimension of per-axis
        quantization, in the case of multiple scales/zero_points.

    Raises:
      ValueError: If `index` is None or negative.
    """
    self._check_index(index)

    return self._interpreter.GetOutputTensorDetails(index)