Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@

import dataclasses
from typing import Any, Optional

import numpy as np

from ai_edge_quantizer import qtyping
from ai_edge_quantizer.algorithms.uniform_quantize import naive_min_max_quantize
from ai_edge_quantizer.algorithms.uniform_quantize import uniform_quantize_tensor
Expand All @@ -35,7 +37,7 @@ def _validate_recovered_weights(
scale: np.ndarray,
tol: float = 1e-4,
):
"""Validates if recovered weights (from the quantized values) are close enough to the original ones.
"""Validates if requantized weights are close enough to the original ones.

Args:
original_vals: Original values before quantization.
Expand All @@ -47,8 +49,9 @@ def _validate_recovered_weights(
RuntimeError: If the maximum difference between original and recovered
values exceeds the tolerance.
"""

recovered_vals = quant_vals * scale
diff = np.abs(recovered_vals - original_vals).flatten()
diff = np.ravel(np.abs(recovered_vals - original_vals))
max_diff = diff.max()
if max_diff > tol:
raise RuntimeError(
Expand Down Expand Up @@ -104,22 +107,20 @@ def get_zp_scale_from_dequantized_symmetric_weights(

if quantized_dimension is None:
# Per-tensor quantization: One scale for the entire tensor.
scales = _get_scale(dequant_vals.flatten(), min_scale)
scales = _get_scale(np.ravel(dequant_vals), min_scale)
scales = np.array([[scales]])
else:
# Per-channel quantization: A scale for each slice along the dimension.
# Create a broadcasted array for per-channel scales. It should have the same
# number of dimensions as the input, with 1 in all dimensions except for the
# quantized dimension, which retains its original size.
scales = np.empty(
tuple(
[
1
if i != quantized_dimension
else dequant_vals.shape[quantized_dimension]
for i in range(dequant_vals.ndim)
]
)
tuple([
1
if i != quantized_dimension
else dequant_vals.shape[quantized_dimension]
for i in range(dequant_vals.ndim)
])
)
for i in range(dequant_vals.shape[quantized_dimension]):
slices = [slice(None)] * dequant_vals.ndim
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -614,8 +614,8 @@ def _test_tensor_transformation_params(
tensor_data, quantization_params
)
self.assertSequenceEqual(
list(expected_quantized_data.flatten()),
list(quantization_params.quantized_data.flatten()), # pytype: disable=attribute-error
np.ravel(expected_quantized_data).tolist(),
np.ravel(quantization_params.quantized_data).tolist(), # pytype: disable=attribute-error
)
elif expected_tensor_max:
max_q = 2**tensor_quant_config.num_bits / 2 - 1
Expand Down
5 changes: 4 additions & 1 deletion ai_edge_quantizer/examples/mnist/quantize_toy_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,10 @@ def _get_calibration_data(
# 5) Save the quantized model and the recipe used to the filesystem.
quant_result.save(_OUTPUT_DIR.value, model_name='mnist_toy_model')

return quant_result.quantized_model
quantized_model = quant_result.quantized_model
if not isinstance(quantized_model, bytearray):
quantized_model = bytearray(quantized_model)
return quantized_model


def inference(quantized_tflite: bytes, image_path: str):
Expand Down
18 changes: 8 additions & 10 deletions ai_edge_quantizer/model_modifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,13 @@
class ModelModifier:
"""Model Modifier class that produce the final quantized TFlite model."""

def __init__(self, float_model: tfl_flatbuffer_utils.ModelT):
def __init__(self, float_model: qtyping.ModelT):
"""Constructor.

Args:
float_model: the original TFlite model.
"""
self._model: tfl_flatbuffer_utils.ModelT = float_model
self._model: qtyping.ModelT = float_model

self._constant_map = []
self._transformation_instruction_generator = (
Expand All @@ -55,7 +55,7 @@ def __init__(self, float_model: tfl_flatbuffer_utils.ModelT):
def _get_tensor_processing_order(
self,
tensor_names: set[str],
flatbuffer_model: tfl_flatbuffer_utils.ModelT,
flatbuffer_model: qtyping.ModelT,
) -> list[str]:
"""Get the tensor processing order obtained from `buffer_to_tensors`.

Expand Down Expand Up @@ -144,10 +144,10 @@ def modify_model(

def _update_signature_defs(
self,
model: tfl_flatbuffer_utils.ModelT,
model: qtyping.ModelT,
serialized_model: bytearray,
suffix: str,
) -> tfl_flatbuffer_utils.ModelT:
) -> qtyping.ModelT:
"""Updates the signature definitions in the model.

This function is called when a transformation (quantize or dequantize)
Expand Down Expand Up @@ -220,9 +220,7 @@ def _has_transform_before_output(
return True
return False

def _process_constant_map(
self, quantized_model: tfl_flatbuffer_utils.ModelT
) -> int:
def _process_constant_map(self, quantized_model: qtyping.ModelT) -> int:
"""Process the constant map after all transformations are applied.

If the resulting model is > 2GB then we would need to serialize constants
Expand Down Expand Up @@ -256,7 +254,7 @@ def _pad_bytearray(self, bytearr: bytearray):

# TODO: b/333797307 - support > 2GB output model
def _serialize_large_model(
self, quantized_model: tfl_flatbuffer_utils.ModelT
self, quantized_model: qtyping.ModelT
) -> bytearray:
"""serialize models > 2GB.

Expand Down Expand Up @@ -304,7 +302,7 @@ def _serialize_large_model(
return model_bytearray

def _serialize_small_model(
self, quantized_model: tfl_flatbuffer_utils.ModelT
self, quantized_model: qtyping.ModelT
) -> bytearray:
"""serialize models < 2GB.

Expand Down
2 changes: 1 addition & 1 deletion ai_edge_quantizer/model_modifier_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def test_modify_model_succeeds_with_recipe(self):
)
self.assertIsInstance(
flatbuffer_utils.convert_bytearray_to_object(new_model_binary),
tfl_flatbuffer_utils.ModelT,
qtyping.ModelT,
)
self.assertLess(len(new_model_binary), len(self._model_content))

Expand Down
4 changes: 2 additions & 2 deletions ai_edge_quantizer/params_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@
class ParamsGenerator:
"""Generate model tensor level quantization parameters."""

def __init__(self, float_tflite: tfl_flatbuffer_utils.ModelT):
self.float_model: tfl_flatbuffer_utils.ModelT = float_tflite
def __init__(self, float_tflite: qtyping.ModelT):
self.float_model: qtyping.ModelT = float_tflite

if not tfl_flatbuffer_utils.is_float_model(self.float_model):
warnings.warn(
Expand Down
33 changes: 29 additions & 4 deletions ai_edge_quantizer/qtyping.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,35 @@
import numpy as np
from typing_extensions import TypeAlias

from ai_edge_litert.tools import flatbuffer_utils

QSV: TypeAlias = MutableMapping[str, Any]

# Types imported from `schema_py_generated`.
Buffer = flatbuffer_utils.Buffer
BufferT = flatbuffer_utils.BufferT
BuiltinOperator = flatbuffer_utils.BuiltinOperator
BuiltinOptions = flatbuffer_utils.BuiltinOptions
BuiltinOptions2 = flatbuffer_utils.BuiltinOptions2
Model = flatbuffer_utils.Model
ModelT = flatbuffer_utils.ModelT
Operator = flatbuffer_utils.Operator
OperatorCode = flatbuffer_utils.OperatorCode
OperatorCodeT = flatbuffer_utils.OperatorCodeT
OperatorT = flatbuffer_utils.OperatorT
StableHLOCompositeOptions = flatbuffer_utils.StableHLOCompositeOptions
StableHLOCompositeOptionsT = flatbuffer_utils.StableHLOCompositeOptionsT
SubGraph = flatbuffer_utils.SubGraph
SubGraphT = flatbuffer_utils.SubGraphT
Tensor = flatbuffer_utils.Tensor
TensorT = flatbuffer_utils.TensorT
TensorType = flatbuffer_utils.TensorType

# Local convenience types.
Path = flatbuffer_utils.Path
BufferType = flatbuffer_utils.BufferType
Endiness = flatbuffer_utils.Endiness


class TFLOperationName(str, enum.Enum):
"""TF Lite operation names."""
Expand Down Expand Up @@ -463,8 +489,8 @@ class GraphInfo:
buffers: Buffers in the subgraph.
"""

subgraph_tensors: list[Any]
buffers: list[Any]
subgraph_tensors: list[TensorT]
buffers: list[BufferT]


@dataclasses.dataclass(frozen=True)
Expand All @@ -478,7 +504,7 @@ class OpInfo:
op_quant_config: The quantization configuration for the op.
"""

op: Any
op: OperatorT
op_name: TFLOperationName
subgraph_op_index: int # Position of the op in the subgraph.
op_quant_config: OpQuantizationConfig
Expand Down Expand Up @@ -578,7 +604,6 @@ class IOOperator:
outputs: list[int]
op_key: TFLOperationName


# The function signature for `get_tensor_quant_params_fn`.
GetTensorQuantParamsFuncSignature = Callable[
[
Expand Down
12 changes: 7 additions & 5 deletions ai_edge_quantizer/quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ class QuantizationResult:
"""

recipe: _QuantRecipe
quantized_model: Optional[bytearray]
quantized_model: Optional[qtyping.BufferType]

def save(
self, save_folder: Path, model_name: str, overwrite: bool = False
Expand Down Expand Up @@ -140,9 +140,11 @@ class Quantizer:

def __init__(
self,
float_model: Union[Path, bytes, bytearray, memoryview],
float_model: Union[Path, qtyping.BufferType],
quantization_recipe: Optional[Union[Path, _QuantRecipe]] = None,
previous_quantized_model: Optional[Union[Path, bytearray]] = None,
previous_quantized_model: Optional[
Union[Path, qtyping.BufferType]
] = None,
):
"""Initializes the quantizer.

Expand Down Expand Up @@ -172,8 +174,8 @@ def __init__(
# Extract the `float_model` from the buffer. Note that this will not
# duplicate the model's data, i.e. all arrays are views on the data of the
# underlying buffer.
self._float_model: tfl_flatbuffer_utils.ModelT = (
tfl_flatbuffer_utils.read_model(self._float_model_buffer)
self._float_model: qtyping.ModelT = tfl_flatbuffer_utils.read_model(
self._float_model_buffer
)

self._recipe_manager: recipe_manager.RecipeManager = (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def insert_decomposed_hadamard_rotation(
hadamard_matrix_tensor_id = transformation_utils.add_new_constant_tensor(
tensor_name=tensor.name + b'_hadamard_matrix',
data=transformation_utils.pack_data(
bitwidth=4, flattened_data=hadamard_matrix.flatten()
bitwidth=4, flattened_data=np.ravel(hadamard_matrix)
),
tensor_type=schema_py_generated.TensorType.INT4,
subgraph=transformation_input.subgraph,
Expand Down
30 changes: 15 additions & 15 deletions ai_edge_quantizer/transformations/quantize_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@

"""quantize a given tensor."""

from typing import Optional, cast
from typing import Optional

import ml_dtypes
import numpy as np

from ai_edge_quantizer import qtyping
from ai_edge_quantizer.transformations import transformation_utils
from ai_edge_litert import schema_py_generated # pylint: disable=g-direct-tensorflow-import
Expand Down Expand Up @@ -84,13 +86,13 @@ def _perform_channelwise_quantization(
transformation_input.quant_params, qtyping.UniformQuantParams
)
flatbuffer_quantization = schema_py_generated.QuantizationParametersT()
flatbuffer_quantization.scale = list(
transformation_input.quant_params.scale.flatten().astype(np.float32)
) # Flatbuffer requires scale as list[float].
flatbuffer_quantization.scale = np.ravel(
transformation_input.quant_params.scale
).astype(np.float32)
if transformation_input.quant_params.zero_point is not None:
flatbuffer_quantization.zeroPoint = list(
transformation_input.quant_params.zero_point.flatten().astype(np.int64)
) # Flatbuffer requires zeroPoint as list[int64]
flatbuffer_quantization.zeroPoint = np.ravel(
transformation_input.quant_params.zero_point
).astype(np.int64)
if transformation_input.quant_params.quantized_dimension is not None:
flatbuffer_quantization.quantizedDimension = (
transformation_input.quant_params.quantized_dimension
Expand Down Expand Up @@ -157,21 +159,19 @@ def quantize_tensor(
num_ops_added: The total number of ops inserted by this operation, which
is 0.
"""
tensor = transformation_input.subgraph.tensors[transformation_input.tensor_id]
tensor: schema_py_generated.TensorT = transformation_input.subgraph.tensors[
transformation_input.tensor_id
]
# TODO: b/336385820 - Suppport quantize buffer directly when quantized_data
# is not provided.
if tensor.buffer:
if transformation_input.quant_params.quantized_data is not None:
transformation_input.buffers[tensor.buffer].data = (
transformation_utils.pack_data(
transformation_input.quant_params.num_bits,
np.frombuffer(
cast(
np.ndarray,
transformation_input.quant_params.quantized_data,
).tobytes(),
dtype=np.uint8,
).flatten(),
np.ravel(
np.asarray(transformation_input.quant_params.quantized_data)
).view(np.uint8),
)
)

Expand Down
2 changes: 1 addition & 1 deletion ai_edge_quantizer/transformations/transformation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def get_constant_buffer(

if isinstance(data, np.ndarray):
# in the case where the data is passed from quantization_params.
new_data = np.frombuffer(data.tobytes(), dtype=np.uint8).flatten()
new_data = np.ravel(data.view(np.uint8))
elif isinstance(data, bytes):
# in the case where the data is coming from duplicating buffers, we need to
# make a copy of the data to avoid having two buffers pointing to the same
Expand Down
Loading
Loading