Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
151 changes: 146 additions & 5 deletions ai_edge_quantizer/calibrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,38 +17,163 @@

from collections.abc import Callable, Iterable
import copy
import enum
import json
from typing import Any, Union

from absl import logging
import numpy as np

import os
from ai_edge_quantizer import algorithm_manager
from ai_edge_quantizer import default_policy as policy
from ai_edge_quantizer import qtyping
from ai_edge_quantizer import recipe
from ai_edge_quantizer import recipe_manager
from ai_edge_quantizer.utils import calibration_utils
from ai_edge_quantizer.utils import progress_utils
from ai_edge_quantizer.utils import tfl_flatbuffer_utils
from ai_edge_quantizer.utils import tfl_interpreter_utils


class CalibrationMode(enum.Enum):
INFERENCE = 1
CALIBRATION = 2


_SignatureInput = dict[str, Any] # input_argument_name -> tensor_value.
_SignatureOutput = dict[
str, np.ndarray
] # output_argument_name -> tensor_value.


class CalibrationInterpreter:
"""A TFL interpreter-like interface for model calibration.

This is a wrapper around Calibrator that replaces the TFL Interpreter to
enable calibration. If mode is CALIBRATION, it runs calibration, otherwise it
acts as a regular TFL interpreter for inference. When in CALIBRATION mode,
each invocation of a signature runner will update the calibration statistics
in self._calibrator. Calibrator is needed in both modes because it contains
tfl interpreter instance to run the model.
"""

def __init__(
self,
model_path: str,
mode: CalibrationMode = CalibrationMode.INFERENCE,
):
"""Initializes the CalibrationInterpreter.

Args:
model_path: The path to the TFLite model.
mode: The mode of the interpreter. If CALIBRATION, the interpreter will
preserve all tensors for calibration purposes.
"""
self._calibrator = Calibrator(
model_path,
interpreter_preserve_all_tensors=(mode == CalibrationMode.CALIBRATION),
)
self._mode = mode

def get_signature_runner(self, signature_key: str | None = None):
"""Returns the signature runner."""
return CalibrationSignatureRunner(
self._calibrator, signature_key, self._mode
)

def get_calibration_results(self):
"""Returns the calibration results."""
if self._mode == CalibrationMode.INFERENCE:
raise ValueError(
"Calibration results are not available in INFERENCE mode."
)
return self._calibrator.get_model_qsvs()

def save_calibration_result(self, output_path: str):
"""Saves the calibration results."""
if self._mode == CalibrationMode.INFERENCE:
raise ValueError(
"Calibration results are not available in INFERENCE mode."
)
self._calibrator.save_calibration_result(output_path)

def get_signature_list(self) -> list[str]:
"""Returns the signature list."""
return self._calibrator.get_signature_list()


class CalibrationSignatureRunner:
"""Wrapper around TFL signature runner to enable calibration."""

def __init__(
self,
calibrator_obj: "Calibrator",
signature_key: str | None = None,
mode: CalibrationMode = CalibrationMode.INFERENCE,
quantization_recipe: recipe_manager.ModelQuantizationRecipe = recipe.static_wi8_ai8(),
):
"""Initializes the CalibrationSignatureRunner.

Args:
calibrator_obj: The Calibrator instance.
signature_key: The key of the signature to run. If None, the default
signature is used.
mode: The mode of the runner. If CALIBRATION, invoking the runner will
update 'calibrator_obj' with new quantization statistics values. If
INFERENCE, the runner behaves like a standard signature runner.
quantization_recipe: The quantization recipe to use for calibration.
Defaults to static_wi8_ai8.
"""
self._calibrator = calibrator_obj
self._signature_key = signature_key
self._mode = mode
self._recipe_manager = recipe_manager.RecipeManager()
self._recipe_manager.load_quantization_recipe(quantization_recipe)
self._signature_runner = (
self._calibrator._tfl_interpreter.get_signature_runner(
self._signature_key
)
)

def __call__(self, **kwargs):
if self._mode == CalibrationMode.INFERENCE:
return self._signature_runner(**kwargs)
self._calibrator.calibrate(
calibration_dataset={self._signature_key: [kwargs]},
model_recipe_manager=self._recipe_manager,
cache_output=True,
)
outputs = self._calibrator.get_cached_output()
assert len(outputs) == 1
self._calibrator.clear_cached_output()
return outputs[0]

def get_input_details(self):
"""Returns the input details of the model."""
return self._signature_runner.get_input_details()

def get_output_details(self):
"""Returns the output details of the model."""
return self._signature_runner.get_output_details()


class Calibrator:
"""Calibrator for TFLite model."""

def __init__(
self,
float_tflite: Union[str, bytes],
num_threads: int = 16,
interpreter_preserve_all_tensors: bool = True,
):
self._flatbuffer_model = tfl_flatbuffer_utils.read_model(float_tflite)

self._tfl_interpreter = tfl_interpreter_utils.create_tfl_interpreter(
float_tflite, use_xnnpack=True, num_threads=num_threads
float_tflite,
use_xnnpack=True,
num_threads=num_threads,
preserve_all_tensors=interpreter_preserve_all_tensors,
)
# Tensor name to tensor content.
self._tensor_content_map: dict[str, Any] = {}
Expand Down Expand Up @@ -140,7 +265,7 @@ def calibrate(
disable=total_ops
< 1000, # We skip the progress bar for small models and small datasets.
) as pbar:
# TODO: b/329322226 - Enable parallel calibration.
# TODO: b/329322226 - Enable parallel calibration.
for signature_key, dataset in calibration_dataset.items():
# Step0: get subgraph index.
subgraph_idx = tfl_interpreter_utils.get_signature_main_subgraph_index(
Expand Down Expand Up @@ -242,13 +367,29 @@ def reset_model_qsvs(self) -> None:
"""Reset the model qsvs."""
self._model_qsvs = {}

def load_model_qsvs(self, model_qsvs: dict[str, qtyping.QSV]) -> None:
def load_model_qsvs(
self, model_qsvs: Union[str, dict[str, qtyping.QSV]]
) -> None:
"""Load the model qsvs.

Args:
model_qsvs: A dictionary of tensor name to QSV.
model_qsvs: A dictionary of tensor name to QSV or a path to a JSON file
that contains the model qsvs (i.e., from save_calibration_result).
"""
self._model_qsvs = copy.deepcopy(model_qsvs)

if isinstance(model_qsvs, str):
self._model_qsvs = calibration_utils.load_calibration_results(model_qsvs)
else:
self._model_qsvs = copy.deepcopy(model_qsvs)

def save_calibration_result(self, file_path: str) -> None:
"""Saves the calibration result to a json file."""
with open(file_path, "w") as f:
json.dump(self._model_qsvs, f, cls=calibration_utils.NumpyEncoder)

def get_signature_list(self) -> list[str]:
"""Get the signature list of the model."""
return self._tfl_interpreter.get_signature_list()

def _update_qsvs(
self,
Expand Down
124 changes: 124 additions & 0 deletions ai_edge_quantizer/calibrator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,35 @@ def test_calibrate_reshape_with_empty_shape_success(self):
test_calibrator.calibrate(calib_data, self._recipe_manager)
self.assertNotEmpty(test_calibrator.get_model_qsvs())

def test_save_and_load_calibration_result(self):
# Setup some QSV
sample_qsv = {
"serving_default_input_1:0": {
"min": np.array([-10.0]),
"max": np.array([8.0]),
}
}
self._calibrator.load_model_qsvs(sample_qsv)

# Save
temp_file = self.create_tempfile().full_path
self._calibrator.save_calibration_result(temp_file)

# Reset
self._calibrator.reset_model_qsvs()
self.assertEmpty(self._calibrator.get_model_qsvs())

# Load
self._calibrator.load_model_qsvs(temp_file)

# Verify
model_tensor_qsvs = self._calibrator.get_model_qsvs()
self.assertLen(model_tensor_qsvs, 1)
self.assertIn("serving_default_input_1:0", model_tensor_qsvs)
input_qsv = model_tensor_qsvs["serving_default_input_1:0"]
self.assertSequenceAlmostEqual(input_qsv["min"], [-10.0])
self.assertSequenceAlmostEqual(input_qsv["max"], [8.0])


class CalibratorAlreadyQuantizedModelTest(absltest.TestCase):

Expand Down Expand Up @@ -238,5 +267,100 @@ def test_toy_gemma2_calibration_success(self):
self.assertLen(calib.get_model_qsvs(), 202)


class CalibrationInterpreterTest(absltest.TestCase):

def setUp(self):
super().setUp()
np.random.seed(0)
self._test_model_path = str(
pathlib.Path(TEST_DATA_PREFIX_PATH) / "tests/models/single_fc.tflite"
)
self._interpreter = calibrator.CalibrationInterpreter(
self._test_model_path, mode=calibrator.CalibrationMode.CALIBRATION
)

def test_initialization(self):
self.assertIsInstance(self._interpreter, calibrator.CalibrationInterpreter)

def test_get_signature_runner(self):
runner = self._interpreter.get_signature_runner()
self.assertIsInstance(runner, calibrator.CalibrationSignatureRunner)

def test_calibration_flow(self):
runner = self._interpreter.get_signature_runner()

# Run inference which triggers calibration
input_data = np.random.rand(1, 8).astype(np.float32)
output = runner(input_1=input_data)

# Check results
qsvs = self._interpreter.get_calibration_results()
self.assertNotEmpty(qsvs)

# Verify input tensor qsv
self.assertIn("serving_default_input_1:0", qsvs)
self.assertIsNotNone(output)

def test_disable_calibration(self):
interpreter = calibrator.CalibrationInterpreter(
self._test_model_path, mode=calibrator.CalibrationMode.INFERENCE
)
runner = interpreter.get_signature_runner()

input_data = np.random.rand(1, 8).astype(np.float32)
output = runner(input_1=input_data)

with self.assertRaisesRegex(
ValueError, "Calibration results are not available in INFERENCE mode."
):
interpreter.get_calibration_results()
self.assertIsNotNone(output)

def test_save_calibration_result(self):
runner = self._interpreter.get_signature_runner()
input_data = np.random.rand(1, 8).astype(np.float32)
runner(input_1=input_data)

temp_file = self.create_tempfile().full_path
self._interpreter.save_calibration_result(temp_file)

# Verify file exists and has content
with open(temp_file, "r") as f:
content = f.read()
self.assertNotEmpty(content)

def test_get_signature_list(self):
signatures = self._interpreter.get_signature_list()
self.assertNotEmpty(signatures)
self.assertIn("serving_default", signatures)

def test_runner_details(self):
runner = self._interpreter.get_signature_runner()
input_details = runner.get_input_details()
output_details = runner.get_output_details()

self.assertNotEmpty(input_details)
self.assertNotEmpty(output_details)
self.assertIn("input_1", input_details)

def test_output_match_original_interpreter(self):
# Run calibration interpreter
calib_runner = self._interpreter.get_signature_runner()
input_data = np.random.rand(1, 8).astype(np.float32)
calib_output = calib_runner(input_1=input_data)

# Run original interpreter
original_interpreter = tfl_interpreter_utils.create_tfl_interpreter(
self._test_model_path
)
original_runner = original_interpreter.get_signature_runner()
original_output = original_runner(input_1=input_data)

# Compare
self.assertEqual(calib_output.keys(), original_output.keys())
for key in calib_output:
np.testing.assert_array_equal(calib_output[key], original_output[key])


if __name__ == "__main__":
absltest.main()
Loading
Loading