diff --git a/ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py b/ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py index f532d81c..ff0dac24 100644 --- a/ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py +++ b/ai_edge_quantizer/algorithms/uniform_quantize/dequantized_weight_recovery.py @@ -17,7 +17,9 @@ import dataclasses from typing import Any, Optional + import numpy as np + from ai_edge_quantizer import qtyping from ai_edge_quantizer.algorithms.uniform_quantize import naive_min_max_quantize from ai_edge_quantizer.algorithms.uniform_quantize import uniform_quantize_tensor @@ -35,7 +37,7 @@ def _validate_recovered_weights( scale: np.ndarray, tol: float = 1e-4, ): - """Validates if recovered weights (from the quantized values) are close enough to the original ones. + """Validates if requantized weights are close enough to the original ones. Args: original_vals: Original values before quantization. @@ -47,8 +49,9 @@ def _validate_recovered_weights( RuntimeError: If the maximum difference between original and recovered values exceeds the tolerance. """ + recovered_vals = quant_vals * scale - diff = np.abs(recovered_vals - original_vals).flatten() + diff = np.ravel(np.abs(recovered_vals - original_vals)) max_diff = diff.max() if max_diff > tol: raise RuntimeError( @@ -104,7 +107,7 @@ def get_zp_scale_from_dequantized_symmetric_weights( if quantized_dimension is None: # Per-tensor quantization: One scale for the entire tensor. - scales = _get_scale(dequant_vals.flatten(), min_scale) + scales = _get_scale(np.ravel(dequant_vals), min_scale) scales = np.array([[scales]]) else: # Per-channel quantization: A scale for each slice along the dimension. @@ -112,14 +115,12 @@ def get_zp_scale_from_dequantized_symmetric_weights( # number of dimensions as the input, with 1 in all dimensions except for the # quantized dimension, which retains its original size. scales = np.empty( - tuple( - [ - 1 - if i != quantized_dimension - else dequant_vals.shape[quantized_dimension] - for i in range(dequant_vals.ndim) - ] - ) + tuple([ + 1 + if i != quantized_dimension + else dequant_vals.shape[quantized_dimension] + for i in range(dequant_vals.ndim) + ]) ) for i in range(dequant_vals.shape[quantized_dimension]): slices = [slice(None)] * dequant_vals.ndim diff --git a/ai_edge_quantizer/algorithms/uniform_quantize/op_architecture_tests/test_utils.py b/ai_edge_quantizer/algorithms/uniform_quantize/op_architecture_tests/test_utils.py index cd2f8113..2e1af2a2 100644 --- a/ai_edge_quantizer/algorithms/uniform_quantize/op_architecture_tests/test_utils.py +++ b/ai_edge_quantizer/algorithms/uniform_quantize/op_architecture_tests/test_utils.py @@ -614,8 +614,8 @@ def _test_tensor_transformation_params( tensor_data, quantization_params ) self.assertSequenceEqual( - list(expected_quantized_data.flatten()), - list(quantization_params.quantized_data.flatten()), # pytype: disable=attribute-error + np.ravel(expected_quantized_data).tolist(), + np.ravel(quantization_params.quantized_data).tolist(), # pytype: disable=attribute-error ) elif expected_tensor_max: max_q = 2**tensor_quant_config.num_bits / 2 - 1 diff --git a/ai_edge_quantizer/examples/mnist/quantize_toy_model.py b/ai_edge_quantizer/examples/mnist/quantize_toy_model.py index 518c795e..8278c2ee 100644 --- a/ai_edge_quantizer/examples/mnist/quantize_toy_model.py +++ b/ai_edge_quantizer/examples/mnist/quantize_toy_model.py @@ -153,7 +153,10 @@ def _get_calibration_data( # 5) Save the quantized model and the recipe used to the filesystem. quant_result.save(_OUTPUT_DIR.value, model_name='mnist_toy_model') - return quant_result.quantized_model + quantized_model = quant_result.quantized_model + if not isinstance(quantized_model, bytearray): + quantized_model = bytearray(quantized_model) + return quantized_model def inference(quantized_tflite: bytes, image_path: str): diff --git a/ai_edge_quantizer/model_modifier.py b/ai_edge_quantizer/model_modifier.py index b8b14f40..df2ebbe9 100644 --- a/ai_edge_quantizer/model_modifier.py +++ b/ai_edge_quantizer/model_modifier.py @@ -36,13 +36,13 @@ class ModelModifier: """Model Modifier class that produce the final quantized TFlite model.""" - def __init__(self, float_model: tfl_flatbuffer_utils.ModelT): + def __init__(self, float_model: qtyping.ModelT): """Constructor. Args: float_model: the original TFlite model. """ - self._model: tfl_flatbuffer_utils.ModelT = float_model + self._model: qtyping.ModelT = float_model self._constant_map = [] self._transformation_instruction_generator = ( @@ -55,7 +55,7 @@ def __init__(self, float_model: tfl_flatbuffer_utils.ModelT): def _get_tensor_processing_order( self, tensor_names: set[str], - flatbuffer_model: tfl_flatbuffer_utils.ModelT, + flatbuffer_model: qtyping.ModelT, ) -> list[str]: """Get the tensor processing order obtained from `buffer_to_tensors`. @@ -144,10 +144,10 @@ def modify_model( def _update_signature_defs( self, - model: tfl_flatbuffer_utils.ModelT, + model: qtyping.ModelT, serialized_model: bytearray, suffix: str, - ) -> tfl_flatbuffer_utils.ModelT: + ) -> qtyping.ModelT: """Updates the signature definitions in the model. This function is called when a transformation (quantize or dequantize) @@ -220,9 +220,7 @@ def _has_transform_before_output( return True return False - def _process_constant_map( - self, quantized_model: tfl_flatbuffer_utils.ModelT - ) -> int: + def _process_constant_map(self, quantized_model: qtyping.ModelT) -> int: """Process the constant map after all transformations are applied. If the resulting model is > 2GB then we would need to serialize constants @@ -256,7 +254,7 @@ def _pad_bytearray(self, bytearr: bytearray): # TODO: b/333797307 - support > 2GB output model def _serialize_large_model( - self, quantized_model: tfl_flatbuffer_utils.ModelT + self, quantized_model: qtyping.ModelT ) -> bytearray: """serialize models > 2GB. @@ -304,7 +302,7 @@ def _serialize_large_model( return model_bytearray def _serialize_small_model( - self, quantized_model: tfl_flatbuffer_utils.ModelT + self, quantized_model: qtyping.ModelT ) -> bytearray: """serialize models < 2GB. diff --git a/ai_edge_quantizer/model_modifier_test.py b/ai_edge_quantizer/model_modifier_test.py index 1cf9791e..65541389 100644 --- a/ai_edge_quantizer/model_modifier_test.py +++ b/ai_edge_quantizer/model_modifier_test.py @@ -84,7 +84,7 @@ def test_modify_model_succeeds_with_recipe(self): ) self.assertIsInstance( flatbuffer_utils.convert_bytearray_to_object(new_model_binary), - tfl_flatbuffer_utils.ModelT, + qtyping.ModelT, ) self.assertLess(len(new_model_binary), len(self._model_content)) diff --git a/ai_edge_quantizer/params_generator.py b/ai_edge_quantizer/params_generator.py index a2808370..55f108cf 100644 --- a/ai_edge_quantizer/params_generator.py +++ b/ai_edge_quantizer/params_generator.py @@ -34,8 +34,8 @@ class ParamsGenerator: """Generate model tensor level quantization parameters.""" - def __init__(self, float_tflite: tfl_flatbuffer_utils.ModelT): - self.float_model: tfl_flatbuffer_utils.ModelT = float_tflite + def __init__(self, float_tflite: qtyping.ModelT): + self.float_model: qtyping.ModelT = float_tflite if not tfl_flatbuffer_utils.is_float_model(self.float_model): warnings.warn( diff --git a/ai_edge_quantizer/qtyping.py b/ai_edge_quantizer/qtyping.py index 25d0a04d..b7fe7586 100644 --- a/ai_edge_quantizer/qtyping.py +++ b/ai_edge_quantizer/qtyping.py @@ -25,9 +25,35 @@ import numpy as np from typing_extensions import TypeAlias +from ai_edge_litert.tools import flatbuffer_utils QSV: TypeAlias = MutableMapping[str, Any] +# Types imported from `schema_py_generated`. +Buffer = flatbuffer_utils.Buffer +BufferT = flatbuffer_utils.BufferT +BuiltinOperator = flatbuffer_utils.BuiltinOperator +BuiltinOptions = flatbuffer_utils.BuiltinOptions +BuiltinOptions2 = flatbuffer_utils.BuiltinOptions2 +Model = flatbuffer_utils.Model +ModelT = flatbuffer_utils.ModelT +Operator = flatbuffer_utils.Operator +OperatorCode = flatbuffer_utils.OperatorCode +OperatorCodeT = flatbuffer_utils.OperatorCodeT +OperatorT = flatbuffer_utils.OperatorT +StableHLOCompositeOptions = flatbuffer_utils.StableHLOCompositeOptions +StableHLOCompositeOptionsT = flatbuffer_utils.StableHLOCompositeOptionsT +SubGraph = flatbuffer_utils.SubGraph +SubGraphT = flatbuffer_utils.SubGraphT +Tensor = flatbuffer_utils.Tensor +TensorT = flatbuffer_utils.TensorT +TensorType = flatbuffer_utils.TensorType + +# Local convenience types. +Path = flatbuffer_utils.Path +BufferType = flatbuffer_utils.BufferType +Endiness = flatbuffer_utils.Endiness + class TFLOperationName(str, enum.Enum): """TF Lite operation names.""" @@ -463,8 +489,8 @@ class GraphInfo: buffers: Buffers in the subgraph. """ - subgraph_tensors: list[Any] - buffers: list[Any] + subgraph_tensors: list[TensorT] + buffers: list[BufferT] @dataclasses.dataclass(frozen=True) @@ -478,7 +504,7 @@ class OpInfo: op_quant_config: The quantization configuration for the op. """ - op: Any + op: OperatorT op_name: TFLOperationName subgraph_op_index: int # Position of the op in the subgraph. op_quant_config: OpQuantizationConfig @@ -578,7 +604,6 @@ class IOOperator: outputs: list[int] op_key: TFLOperationName - # The function signature for `get_tensor_quant_params_fn`. GetTensorQuantParamsFuncSignature = Callable[ [ diff --git a/ai_edge_quantizer/quantizer.py b/ai_edge_quantizer/quantizer.py index 39d36a76..0a8a2b8f 100644 --- a/ai_edge_quantizer/quantizer.py +++ b/ai_edge_quantizer/quantizer.py @@ -59,7 +59,7 @@ class QuantizationResult: """ recipe: _QuantRecipe - quantized_model: Optional[bytearray] + quantized_model: Optional[qtyping.BufferType] def save( self, save_folder: Path, model_name: str, overwrite: bool = False @@ -140,9 +140,11 @@ class Quantizer: def __init__( self, - float_model: Union[Path, bytes, bytearray, memoryview], + float_model: Union[Path, qtyping.BufferType], quantization_recipe: Optional[Union[Path, _QuantRecipe]] = None, - previous_quantized_model: Optional[Union[Path, bytearray]] = None, + previous_quantized_model: Optional[ + Union[Path, qtyping.BufferType] + ] = None, ): """Initializes the quantizer. @@ -172,8 +174,8 @@ def __init__( # Extract the `float_model` from the buffer. Note that this will not # duplicate the model's data, i.e. all arrays are views on the data of the # underlying buffer. - self._float_model: tfl_flatbuffer_utils.ModelT = ( - tfl_flatbuffer_utils.read_model(self._float_model_buffer) + self._float_model: qtyping.ModelT = tfl_flatbuffer_utils.read_model( + self._float_model_buffer ) self._recipe_manager: recipe_manager.RecipeManager = ( diff --git a/ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation.py b/ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation.py index 91dbde91..9e3ca1cb 100644 --- a/ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation.py +++ b/ai_edge_quantizer/transformations/insert_decomposed_hadamard_rotation.py @@ -198,7 +198,7 @@ def insert_decomposed_hadamard_rotation( hadamard_matrix_tensor_id = transformation_utils.add_new_constant_tensor( tensor_name=tensor.name + b'_hadamard_matrix', data=transformation_utils.pack_data( - bitwidth=4, flattened_data=hadamard_matrix.flatten() + bitwidth=4, flattened_data=np.ravel(hadamard_matrix) ), tensor_type=schema_py_generated.TensorType.INT4, subgraph=transformation_input.subgraph, diff --git a/ai_edge_quantizer/transformations/quantize_tensor.py b/ai_edge_quantizer/transformations/quantize_tensor.py index 5427df7a..b05cfa11 100644 --- a/ai_edge_quantizer/transformations/quantize_tensor.py +++ b/ai_edge_quantizer/transformations/quantize_tensor.py @@ -15,9 +15,11 @@ """quantize a given tensor.""" -from typing import Optional, cast +from typing import Optional + import ml_dtypes import numpy as np + from ai_edge_quantizer import qtyping from ai_edge_quantizer.transformations import transformation_utils from ai_edge_litert import schema_py_generated # pylint: disable=g-direct-tensorflow-import @@ -84,13 +86,13 @@ def _perform_channelwise_quantization( transformation_input.quant_params, qtyping.UniformQuantParams ) flatbuffer_quantization = schema_py_generated.QuantizationParametersT() - flatbuffer_quantization.scale = list( - transformation_input.quant_params.scale.flatten().astype(np.float32) - ) # Flatbuffer requires scale as list[float]. + flatbuffer_quantization.scale = np.ravel( + transformation_input.quant_params.scale + ).astype(np.float32) if transformation_input.quant_params.zero_point is not None: - flatbuffer_quantization.zeroPoint = list( - transformation_input.quant_params.zero_point.flatten().astype(np.int64) - ) # Flatbuffer requires zeroPoint as list[int64] + flatbuffer_quantization.zeroPoint = np.ravel( + transformation_input.quant_params.zero_point + ).astype(np.int64) if transformation_input.quant_params.quantized_dimension is not None: flatbuffer_quantization.quantizedDimension = ( transformation_input.quant_params.quantized_dimension @@ -157,7 +159,9 @@ def quantize_tensor( num_ops_added: The total number of ops inserted by this operation, which is 0. """ - tensor = transformation_input.subgraph.tensors[transformation_input.tensor_id] + tensor: schema_py_generated.TensorT = transformation_input.subgraph.tensors[ + transformation_input.tensor_id + ] # TODO: b/336385820 - Suppport quantize buffer directly when quantized_data # is not provided. if tensor.buffer: @@ -165,13 +169,9 @@ def quantize_tensor( transformation_input.buffers[tensor.buffer].data = ( transformation_utils.pack_data( transformation_input.quant_params.num_bits, - np.frombuffer( - cast( - np.ndarray, - transformation_input.quant_params.quantized_data, - ).tobytes(), - dtype=np.uint8, - ).flatten(), + np.ravel( + np.asarray(transformation_input.quant_params.quantized_data) + ).view(np.uint8), ) ) diff --git a/ai_edge_quantizer/transformations/transformation_utils.py b/ai_edge_quantizer/transformations/transformation_utils.py index 92292abf..3552b786 100644 --- a/ai_edge_quantizer/transformations/transformation_utils.py +++ b/ai_edge_quantizer/transformations/transformation_utils.py @@ -109,7 +109,7 @@ def get_constant_buffer( if isinstance(data, np.ndarray): # in the case where the data is passed from quantization_params. - new_data = np.frombuffer(data.tobytes(), dtype=np.uint8).flatten() + new_data = np.ravel(data.view(np.uint8)) elif isinstance(data, bytes): # in the case where the data is coming from duplicating buffers, we need to # make a copy of the data to avoid having two buffers pointing to the same diff --git a/ai_edge_quantizer/utils/tfl_flatbuffer_utils.py b/ai_edge_quantizer/utils/tfl_flatbuffer_utils.py index 06db7813..4ae1592a 100644 --- a/ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +++ b/ai_edge_quantizer/utils/tfl_flatbuffer_utils.py @@ -15,11 +15,11 @@ """flatbuffer utils for the Quantizer.""" +import collections import logging import mmap import os import pathlib -from typing import Any, Optional, Union import immutabledict import numpy as np @@ -28,70 +28,65 @@ import io from ai_edge_litert.tools import flatbuffer_utils from ai_edge_quantizer import qtyping -from ai_edge_litert import schema_py_generated as schema # pylint:disable=g-direct-tensorflow-import - - -# Export some common schema types. -ModelT = schema.ModelT _TFLOpName = qtyping.TFLOperationName -Path = Union[str, pathlib.Path] +Path = str | pathlib.Path TFL_OP_NAME_TO_CODE = immutabledict.immutabledict({ - _TFLOpName.FULLY_CONNECTED: schema.BuiltinOperator.FULLY_CONNECTED, - _TFLOpName.BATCH_MATMUL: schema.BuiltinOperator.BATCH_MATMUL, - _TFLOpName.CONV_2D: schema.BuiltinOperator.CONV_2D, - _TFLOpName.DEPTHWISE_CONV_2D: schema.BuiltinOperator.DEPTHWISE_CONV_2D, - _TFLOpName.CONV_2D_TRANSPOSE: schema.BuiltinOperator.TRANSPOSE_CONV, - _TFLOpName.EMBEDDING_LOOKUP: schema.BuiltinOperator.EMBEDDING_LOOKUP, - _TFLOpName.SOFTMAX: schema.BuiltinOperator.SOFTMAX, - _TFLOpName.AVERAGE_POOL_2D: schema.BuiltinOperator.AVERAGE_POOL_2D, - _TFLOpName.RESHAPE: schema.BuiltinOperator.RESHAPE, - _TFLOpName.TANH: schema.BuiltinOperator.TANH, - _TFLOpName.TRANSPOSE: schema.BuiltinOperator.TRANSPOSE, - _TFLOpName.GELU: schema.BuiltinOperator.GELU, - _TFLOpName.ADD: schema.BuiltinOperator.ADD, - _TFLOpName.SUB: schema.BuiltinOperator.SUB, - _TFLOpName.MUL: schema.BuiltinOperator.MUL, - _TFLOpName.MEAN: schema.BuiltinOperator.MEAN, - _TFLOpName.RSQRT: schema.BuiltinOperator.RSQRT, - _TFLOpName.CONCATENATION: schema.BuiltinOperator.CONCATENATION, - _TFLOpName.STRIDED_SLICE: schema.BuiltinOperator.STRIDED_SLICE, - _TFLOpName.SPLIT: schema.BuiltinOperator.SPLIT, - _TFLOpName.LOGISTIC: schema.BuiltinOperator.LOGISTIC, - _TFLOpName.SLICE: schema.BuiltinOperator.SLICE, - _TFLOpName.SUM: schema.BuiltinOperator.SUM, - _TFLOpName.SELECT: schema.BuiltinOperator.SELECT, - _TFLOpName.SELECT_V2: schema.BuiltinOperator.SELECT_V2, - _TFLOpName.STABLEHLO_COMPOSITE: schema.BuiltinOperator.STABLEHLO_COMPOSITE, + _TFLOpName.FULLY_CONNECTED: qtyping.BuiltinOperator.FULLY_CONNECTED, + _TFLOpName.BATCH_MATMUL: qtyping.BuiltinOperator.BATCH_MATMUL, + _TFLOpName.CONV_2D: qtyping.BuiltinOperator.CONV_2D, + _TFLOpName.DEPTHWISE_CONV_2D: qtyping.BuiltinOperator.DEPTHWISE_CONV_2D, + _TFLOpName.CONV_2D_TRANSPOSE: qtyping.BuiltinOperator.TRANSPOSE_CONV, + _TFLOpName.EMBEDDING_LOOKUP: qtyping.BuiltinOperator.EMBEDDING_LOOKUP, + _TFLOpName.SOFTMAX: qtyping.BuiltinOperator.SOFTMAX, + _TFLOpName.AVERAGE_POOL_2D: qtyping.BuiltinOperator.AVERAGE_POOL_2D, + _TFLOpName.RESHAPE: qtyping.BuiltinOperator.RESHAPE, + _TFLOpName.TANH: qtyping.BuiltinOperator.TANH, + _TFLOpName.TRANSPOSE: qtyping.BuiltinOperator.TRANSPOSE, + _TFLOpName.GELU: qtyping.BuiltinOperator.GELU, + _TFLOpName.ADD: qtyping.BuiltinOperator.ADD, + _TFLOpName.SUB: qtyping.BuiltinOperator.SUB, + _TFLOpName.MUL: qtyping.BuiltinOperator.MUL, + _TFLOpName.MEAN: qtyping.BuiltinOperator.MEAN, + _TFLOpName.RSQRT: qtyping.BuiltinOperator.RSQRT, + _TFLOpName.CONCATENATION: qtyping.BuiltinOperator.CONCATENATION, + _TFLOpName.STRIDED_SLICE: qtyping.BuiltinOperator.STRIDED_SLICE, + _TFLOpName.SPLIT: qtyping.BuiltinOperator.SPLIT, + _TFLOpName.LOGISTIC: qtyping.BuiltinOperator.LOGISTIC, + _TFLOpName.SLICE: qtyping.BuiltinOperator.SLICE, + _TFLOpName.SUM: qtyping.BuiltinOperator.SUM, + _TFLOpName.SELECT: qtyping.BuiltinOperator.SELECT, + _TFLOpName.SELECT_V2: qtyping.BuiltinOperator.SELECT_V2, + _TFLOpName.STABLEHLO_COMPOSITE: qtyping.BuiltinOperator.STABLEHLO_COMPOSITE, _TFLOpName.DYNAMIC_UPDATE_SLICE: ( - schema.BuiltinOperator.DYNAMIC_UPDATE_SLICE + qtyping.BuiltinOperator.DYNAMIC_UPDATE_SLICE ), - _TFLOpName.PAD: schema.BuiltinOperator.PAD, - _TFLOpName.SQUARED_DIFFERENCE: schema.BuiltinOperator.SQUARED_DIFFERENCE, - _TFLOpName.MAX_POOL_2D: schema.BuiltinOperator.MAX_POOL_2D, - _TFLOpName.RESIZE_BILINEAR: schema.BuiltinOperator.RESIZE_BILINEAR, + _TFLOpName.PAD: qtyping.BuiltinOperator.PAD, + _TFLOpName.SQUARED_DIFFERENCE: qtyping.BuiltinOperator.SQUARED_DIFFERENCE, + _TFLOpName.MAX_POOL_2D: qtyping.BuiltinOperator.MAX_POOL_2D, + _TFLOpName.RESIZE_BILINEAR: qtyping.BuiltinOperator.RESIZE_BILINEAR, _TFLOpName.RESIZE_NEAREST_NEIGHBOR: ( - schema.BuiltinOperator.RESIZE_NEAREST_NEIGHBOR + qtyping.BuiltinOperator.RESIZE_NEAREST_NEIGHBOR ), - _TFLOpName.GATHER_ND: schema.BuiltinOperator.GATHER_ND, - _TFLOpName.PACK: schema.BuiltinOperator.PACK, - _TFLOpName.UNPACK: schema.BuiltinOperator.UNPACK, - _TFLOpName.DIV: schema.BuiltinOperator.DIV, - _TFLOpName.BROADCAST_TO: schema.BuiltinOperator.BROADCAST_TO, - _TFLOpName.SQRT: schema.BuiltinOperator.SQRT, - _TFLOpName.GATHER: schema.BuiltinOperator.GATHER, - _TFLOpName.HARD_SWISH: schema.BuiltinOperator.HARD_SWISH, - _TFLOpName.MAXIMUM: schema.BuiltinOperator.MAXIMUM, - _TFLOpName.PADV2: schema.BuiltinOperator.PADV2, - _TFLOpName.REDUCE_MIN: schema.BuiltinOperator.REDUCE_MIN, - _TFLOpName.EQUAL: schema.BuiltinOperator.EQUAL, - _TFLOpName.NOT_EQUAL: schema.BuiltinOperator.NOT_EQUAL, - _TFLOpName.MIRROR_PAD: schema.BuiltinOperator.MIRROR_PAD, - _TFLOpName.SPACE_TO_DEPTH: schema.BuiltinOperator.SPACE_TO_DEPTH, - _TFLOpName.RELU: schema.BuiltinOperator.RELU, + _TFLOpName.GATHER_ND: qtyping.BuiltinOperator.GATHER_ND, + _TFLOpName.PACK: qtyping.BuiltinOperator.PACK, + _TFLOpName.UNPACK: qtyping.BuiltinOperator.UNPACK, + _TFLOpName.DIV: qtyping.BuiltinOperator.DIV, + _TFLOpName.BROADCAST_TO: qtyping.BuiltinOperator.BROADCAST_TO, + _TFLOpName.SQRT: qtyping.BuiltinOperator.SQRT, + _TFLOpName.GATHER: qtyping.BuiltinOperator.GATHER, + _TFLOpName.HARD_SWISH: qtyping.BuiltinOperator.HARD_SWISH, + _TFLOpName.MAXIMUM: qtyping.BuiltinOperator.MAXIMUM, + _TFLOpName.PADV2: qtyping.BuiltinOperator.PADV2, + _TFLOpName.REDUCE_MIN: qtyping.BuiltinOperator.REDUCE_MIN, + _TFLOpName.EQUAL: qtyping.BuiltinOperator.EQUAL, + _TFLOpName.NOT_EQUAL: qtyping.BuiltinOperator.NOT_EQUAL, + _TFLOpName.MIRROR_PAD: qtyping.BuiltinOperator.MIRROR_PAD, + _TFLOpName.SPACE_TO_DEPTH: qtyping.BuiltinOperator.SPACE_TO_DEPTH, + _TFLOpName.RELU: qtyping.BuiltinOperator.RELU, }) TFL_OP_CODE_TO_NAME = immutabledict.immutabledict( @@ -114,12 +109,9 @@ }) NUM_TFL_DATATYPES = 18 -TENSOR_CODE_TO_TYPE = {} -for dtype_code in range(NUM_TFL_DATATYPES): - TENSOR_CODE_TO_TYPE[dtype_code] = flatbuffer_utils.type_to_name(dtype_code) -TENSOR_CODE_TO_TYPE = immutabledict.immutabledict(TENSOR_CODE_TO_TYPE) -TENSOR_TYPE_TO_CODE = immutabledict.immutabledict( # pytype: disable=wrong-arg-types - (reversed(item) for item in TENSOR_CODE_TO_TYPE.items()) +TENSOR_TYPE_TO_CODE = immutabledict.immutabledict(qtyping.TensorType.__dict__) +TENSOR_CODE_TO_TYPE = immutabledict.immutabledict( + {v: k for k, v in qtyping.TensorType.__dict__.items()} ) # Expose functions in litert.python.tools.flatbuffer_utils @@ -127,8 +119,8 @@ def read_model( - tflite_model: Union[Path, bytearray, bytes, memoryview], -) -> schema.ModelT: + tflite_model: Path | qtyping.BufferType, +) -> qtyping.ModelT: """Read and convert the TFLite model into a flatbuffer object. Args: @@ -210,7 +202,9 @@ def get_model_buffer(tflite_path: Path) -> bytearray: return model_bytearray -def parse_op_tensors(op: Any, subgraph_tensors: list[Any]) -> list[Any]: +def parse_op_tensors( + op: qtyping.OperatorT, subgraph_tensors: list[qtyping.TensorT] +) -> list[qtyping.TensorT]: """Parse the op tensors. Args: @@ -221,21 +215,23 @@ def parse_op_tensors(op: Any, subgraph_tensors: list[Any]) -> list[Any]: tensors: list of tensors that are associated with the op. """ - tensors = [] - for tensor_idx in list(op.outputs) + list(op.inputs): - if tensor_idx != -1: - tensors.append(subgraph_tensors[tensor_idx]) - return tensors + return [ + subgraph_tensors[tensor_idx] + for tensor_idx in list(op.outputs) + list(op.inputs) + if tensor_idx != -1 + ] def parse_fc_bmm_conv_tensors( - op: Any, - subgraph_tensors: list[Any], + op: qtyping.OperatorT, + subgraph_tensors: list[qtyping.TensorT], input_index: int = 0, weight_index: int = 1, bias_index: int = 2, output_index: int = 0, -) -> tuple[Any, Any, Any, Any]: +) -> tuple[ + qtyping.TensorT, qtyping.TensorT, qtyping.TensorT | None, qtyping.TensorT +]: """Parse tensors in FullyConnected, BatchMatmul, and Convolutions. Args: @@ -259,22 +255,20 @@ def parse_fc_bmm_conv_tensors( return input_tensor, weight_tensor, bias_tensor, output_tensor -# flatbuffer_model has Any type since litert.python.tools.flatbuffer_utils -# is not type annotated. -def buffer_to_tensors(flatbuffer_model: Any) -> dict[int, list[Any]]: +def buffer_to_tensors( + flatbuffer_model: qtyping.ModelT, +) -> dict[int, list[qtyping.TensorT]]: """Returns a map from buffer id to tensors that use it.""" - buffer_to_tensor_map = {} + buffer_to_tensor_map = collections.defaultdict(list) for subgraph in flatbuffer_model.subgraphs: for op in subgraph.operators: for tensor in parse_op_tensors(op, subgraph.tensors): - if tensor.buffer not in buffer_to_tensor_map: - buffer_to_tensor_map[tensor.buffer] = [] if tensor not in buffer_to_tensor_map[tensor.buffer]: buffer_to_tensor_map[tensor.buffer].append(tensor) return buffer_to_tensor_map -def get_tensor_name(tensor: Any) -> str: +def get_tensor_name(tensor: qtyping.TensorT) -> str: """Get the tensor name for a fb tensor. Args: @@ -286,7 +280,9 @@ def get_tensor_name(tensor: Any) -> str: return tensor.name.decode("utf-8") -def get_tensor_data(tensor: Any, buffers: list[Any]) -> Optional[np.ndarray]: +def get_tensor_data( + tensor: qtyping.TensorT, buffers: list[qtyping.BufferT] +) -> np.ndarray | None: """Get the tensor data. Args: @@ -308,7 +304,9 @@ def get_tensor_data(tensor: Any, buffers: list[Any]) -> Optional[np.ndarray]: return data -def has_same_quantization(tensor1: Any, tensor2: Any) -> bool: +def has_same_quantization( + tensor1: qtyping.TensorT, tensor2: qtyping.TensorT +) -> bool: """Check if two tensors have the same quantization. Args: @@ -347,7 +345,7 @@ def to_tuple(val): ) -def is_float_model(flatbuffer_model: Any) -> bool: +def is_float_model(flatbuffer_model: qtyping.ModelT) -> bool: """Checks that the model is float and not already quantized.""" for subgraph in flatbuffer_model.subgraphs: for tensor in subgraph.tensors: @@ -359,7 +357,7 @@ def is_float_model(flatbuffer_model: Any) -> bool: def get_subgraph_input_output_operators( - subgraph: Any, + subgraph: qtyping.SubGraphT, ) -> list[qtyping.IOOperator]: """Get the input/output operators for the subgraph. @@ -383,7 +381,7 @@ def get_subgraph_input_output_operators( def get_op_side_effect_subgraphs( - op: Union[schema.Operator, schema.OperatorT], + op: qtyping.Operator | qtyping.OperatorT, ) -> list[int]: """Get indices of any subgraphs invoked as a side effect of the operator. @@ -395,7 +393,7 @@ def get_op_side_effect_subgraphs( does not invoke any subgraphs. """ if opts := flatbuffer_utils.get_options_as( - op, schema.StableHLOCompositeOptionsT + op, qtyping.StableHLOCompositeOptionsT ): return [opts.decompositionSubgraphIndex] # Can add other nested ops here (control flow ops, etc). @@ -403,7 +401,7 @@ def get_op_side_effect_subgraphs( def get_op_name_by_index( - flatbuffer_model: Any, subgraph_id: int, op_index: int + flatbuffer_model: qtyping.ModelT, subgraph_id: int, op_index: int ) -> str: """Get the op name from the flatbuffer model.""" op = flatbuffer_model.subgraphs[subgraph_id].operators[op_index] @@ -412,7 +410,9 @@ def get_op_name_by_index( def get_op_scope( - op: Any, subgraph_tensors: list[Any], max_length: int = 10000 + op: qtyping.OperatorT, + subgraph_tensors: list[qtyping.TensorT], + max_length: int = 10000, ) -> str: """Get the op scope. diff --git a/ai_edge_quantizer/utils/validation_utils.py b/ai_edge_quantizer/utils/validation_utils.py index 8ebeb27b..9a7d0331 100644 --- a/ai_edge_quantizer/utils/validation_utils.py +++ b/ai_edge_quantizer/utils/validation_utils.py @@ -225,11 +225,11 @@ def _preprocess_same_size_arrays( Raises: ValueError: if the two inputs don't have the same number of elements """ - data1 = np.array(data1, dtype=np.float32).flatten() - data2 = np.array(data2, dtype=np.float32).flatten() + data1 = np.ravel(np.asarray(data1, dtype=np.float32)) + data2 = np.ravel(np.asarray(data2, dtype=np.float32)) if np.shape(data1) != np.shape(data2): raise ValueError("data1 & data2 must be of the same size") - data1 = np.nan_to_num(data1, nan=1e-9, neginf=-1e9, posinf=1e9) - data2 = np.nan_to_num(data2, nan=1e-9, neginf=-1e9, posinf=1e9) + data1 = np.nan_to_num(data1, nan=1e-9, neginf=-1e9, posinf=1e9, copy=False) + data2 = np.nan_to_num(data2, nan=1e-9, neginf=-1e9, posinf=1e9, copy=False) return data1, data2