diff --git a/ai_edge_quantizer/calibrator.py b/ai_edge_quantizer/calibrator.py index dc871ef6..73cbcfff 100644 --- a/ai_edge_quantizer/calibrator.py +++ b/ai_edge_quantizer/calibrator.py @@ -17,26 +17,147 @@ from collections.abc import Callable, Iterable import copy +import enum +import json from typing import Any, Union from absl import logging import numpy as np +import os from ai_edge_quantizer import algorithm_manager from ai_edge_quantizer import default_policy as policy from ai_edge_quantizer import qtyping +from ai_edge_quantizer import recipe from ai_edge_quantizer import recipe_manager from ai_edge_quantizer.utils import calibration_utils from ai_edge_quantizer.utils import progress_utils from ai_edge_quantizer.utils import tfl_flatbuffer_utils from ai_edge_quantizer.utils import tfl_interpreter_utils + +class CalibrationMode(enum.Enum): + INFERENCE = 1 + CALIBRATION = 2 + + _SignatureInput = dict[str, Any] # input_argument_name -> tensor_value. _SignatureOutput = dict[ str, np.ndarray ] # output_argument_name -> tensor_value. +class CalibrationInterpreter: + """A TFL interpreter-like interface for model calibration. + + This is a wrapper around Calibrator that replaces the TFL Interpreter to + enable calibration. If mode is CALIBRATION, it runs calibration, otherwise it + acts as a regular TFL interpreter for inference. When in CALIBRATION mode, + each invocation of a signature runner will update the calibration statistics + in self._calibrator. Calibrator is needed in both modes because it contains + tfl interpreter instance to run the model. + """ + + def __init__( + self, + model_path: str, + mode: CalibrationMode = CalibrationMode.INFERENCE, + ): + """Initializes the CalibrationInterpreter. + + Args: + model_path: The path to the TFLite model. + mode: The mode of the interpreter. If CALIBRATION, the interpreter will + preserve all tensors for calibration purposes. + """ + self._calibrator = Calibrator( + model_path, + interpreter_preserve_all_tensors=(mode == CalibrationMode.CALIBRATION), + ) + self._mode = mode + + def get_signature_runner(self, signature_key: str | None = None): + """Returns the signature runner.""" + return CalibrationSignatureRunner( + self._calibrator, signature_key, self._mode + ) + + def get_calibration_results(self): + """Returns the calibration results.""" + if self._mode == CalibrationMode.INFERENCE: + raise ValueError( + "Calibration results are not available in INFERENCE mode." + ) + return self._calibrator.get_model_qsvs() + + def save_calibration_result(self, output_path: str): + """Saves the calibration results.""" + if self._mode == CalibrationMode.INFERENCE: + raise ValueError( + "Calibration results are not available in INFERENCE mode." + ) + self._calibrator.save_calibration_result(output_path) + + def get_signature_list(self) -> list[str]: + """Returns the signature list.""" + return self._calibrator.get_signature_list() + + +class CalibrationSignatureRunner: + """Wrapper around TFL signature runner to enable calibration.""" + + def __init__( + self, + calibrator_obj: "Calibrator", + signature_key: str | None = None, + mode: CalibrationMode = CalibrationMode.INFERENCE, + quantization_recipe: recipe_manager.ModelQuantizationRecipe = recipe.static_wi8_ai8(), + ): + """Initializes the CalibrationSignatureRunner. + + Args: + calibrator_obj: The Calibrator instance. + signature_key: The key of the signature to run. If None, the default + signature is used. + mode: The mode of the runner. If CALIBRATION, invoking the runner will + update 'calibrator_obj' with new quantization statistics values. If + INFERENCE, the runner behaves like a standard signature runner. + quantization_recipe: The quantization recipe to use for calibration. + Defaults to static_wi8_ai8. + """ + self._calibrator = calibrator_obj + self._signature_key = signature_key + self._mode = mode + self._recipe_manager = recipe_manager.RecipeManager() + self._recipe_manager.load_quantization_recipe(quantization_recipe) + self._signature_runner = ( + self._calibrator._tfl_interpreter.get_signature_runner( + self._signature_key + ) + ) + + def __call__(self, **kwargs): + if self._mode == CalibrationMode.INFERENCE: + return self._signature_runner(**kwargs) + self._calibrator.calibrate( + calibration_dataset={self._signature_key: [kwargs]}, + model_recipe_manager=self._recipe_manager, + cache_output=True, + ) + outputs = self._calibrator.get_cached_output() + assert len(outputs) == 1 + self._calibrator.clear_cached_output() + return outputs[0] + + def get_input_details(self): + """Returns the input details of the model.""" + return self._signature_runner.get_input_details() + + def get_output_details(self): + """Returns the output details of the model.""" + return self._signature_runner.get_output_details() + + class Calibrator: """Calibrator for TFLite model.""" @@ -44,11 +165,15 @@ def __init__( self, float_tflite: Union[str, bytes], num_threads: int = 16, + interpreter_preserve_all_tensors: bool = True, ): self._flatbuffer_model = tfl_flatbuffer_utils.read_model(float_tflite) self._tfl_interpreter = tfl_interpreter_utils.create_tfl_interpreter( - float_tflite, use_xnnpack=True, num_threads=num_threads + float_tflite, + use_xnnpack=True, + num_threads=num_threads, + preserve_all_tensors=interpreter_preserve_all_tensors, ) # Tensor name to tensor content. self._tensor_content_map: dict[str, Any] = {} @@ -140,7 +265,7 @@ def calibrate( disable=total_ops < 1000, # We skip the progress bar for small models and small datasets. ) as pbar: - # TODO: b/329322226 - Enable parallel calibration. + # TODO: b/329322226 - Enable parallel calibration. for signature_key, dataset in calibration_dataset.items(): # Step0: get subgraph index. subgraph_idx = tfl_interpreter_utils.get_signature_main_subgraph_index( @@ -242,13 +367,29 @@ def reset_model_qsvs(self) -> None: """Reset the model qsvs.""" self._model_qsvs = {} - def load_model_qsvs(self, model_qsvs: dict[str, qtyping.QSV]) -> None: + def load_model_qsvs( + self, model_qsvs: Union[str, dict[str, qtyping.QSV]] + ) -> None: """Load the model qsvs. Args: - model_qsvs: A dictionary of tensor name to QSV. + model_qsvs: A dictionary of tensor name to QSV or a path to a JSON file + that contains the model qsvs (i.e., from save_calibration_result). """ - self._model_qsvs = copy.deepcopy(model_qsvs) + + if isinstance(model_qsvs, str): + self._model_qsvs = calibration_utils.load_calibration_results(model_qsvs) + else: + self._model_qsvs = copy.deepcopy(model_qsvs) + + def save_calibration_result(self, file_path: str) -> None: + """Saves the calibration result to a json file.""" + with open(file_path, "w") as f: + json.dump(self._model_qsvs, f, cls=calibration_utils.NumpyEncoder) + + def get_signature_list(self) -> list[str]: + """Get the signature list of the model.""" + return self._tfl_interpreter.get_signature_list() def _update_qsvs( self, diff --git a/ai_edge_quantizer/calibrator_test.py b/ai_edge_quantizer/calibrator_test.py index f5488662..be799331 100644 --- a/ai_edge_quantizer/calibrator_test.py +++ b/ai_edge_quantizer/calibrator_test.py @@ -177,6 +177,35 @@ def test_calibrate_reshape_with_empty_shape_success(self): test_calibrator.calibrate(calib_data, self._recipe_manager) self.assertNotEmpty(test_calibrator.get_model_qsvs()) + def test_save_and_load_calibration_result(self): + # Setup some QSV + sample_qsv = { + "serving_default_input_1:0": { + "min": np.array([-10.0]), + "max": np.array([8.0]), + } + } + self._calibrator.load_model_qsvs(sample_qsv) + + # Save + temp_file = self.create_tempfile().full_path + self._calibrator.save_calibration_result(temp_file) + + # Reset + self._calibrator.reset_model_qsvs() + self.assertEmpty(self._calibrator.get_model_qsvs()) + + # Load + self._calibrator.load_model_qsvs(temp_file) + + # Verify + model_tensor_qsvs = self._calibrator.get_model_qsvs() + self.assertLen(model_tensor_qsvs, 1) + self.assertIn("serving_default_input_1:0", model_tensor_qsvs) + input_qsv = model_tensor_qsvs["serving_default_input_1:0"] + self.assertSequenceAlmostEqual(input_qsv["min"], [-10.0]) + self.assertSequenceAlmostEqual(input_qsv["max"], [8.0]) + class CalibratorAlreadyQuantizedModelTest(absltest.TestCase): @@ -238,5 +267,100 @@ def test_toy_gemma2_calibration_success(self): self.assertLen(calib.get_model_qsvs(), 202) +class CalibrationInterpreterTest(absltest.TestCase): + + def setUp(self): + super().setUp() + np.random.seed(0) + self._test_model_path = str( + pathlib.Path(TEST_DATA_PREFIX_PATH) / "tests/models/single_fc.tflite" + ) + self._interpreter = calibrator.CalibrationInterpreter( + self._test_model_path, mode=calibrator.CalibrationMode.CALIBRATION + ) + + def test_initialization(self): + self.assertIsInstance(self._interpreter, calibrator.CalibrationInterpreter) + + def test_get_signature_runner(self): + runner = self._interpreter.get_signature_runner() + self.assertIsInstance(runner, calibrator.CalibrationSignatureRunner) + + def test_calibration_flow(self): + runner = self._interpreter.get_signature_runner() + + # Run inference which triggers calibration + input_data = np.random.rand(1, 8).astype(np.float32) + output = runner(input_1=input_data) + + # Check results + qsvs = self._interpreter.get_calibration_results() + self.assertNotEmpty(qsvs) + + # Verify input tensor qsv + self.assertIn("serving_default_input_1:0", qsvs) + self.assertIsNotNone(output) + + def test_disable_calibration(self): + interpreter = calibrator.CalibrationInterpreter( + self._test_model_path, mode=calibrator.CalibrationMode.INFERENCE + ) + runner = interpreter.get_signature_runner() + + input_data = np.random.rand(1, 8).astype(np.float32) + output = runner(input_1=input_data) + + with self.assertRaisesRegex( + ValueError, "Calibration results are not available in INFERENCE mode." + ): + interpreter.get_calibration_results() + self.assertIsNotNone(output) + + def test_save_calibration_result(self): + runner = self._interpreter.get_signature_runner() + input_data = np.random.rand(1, 8).astype(np.float32) + runner(input_1=input_data) + + temp_file = self.create_tempfile().full_path + self._interpreter.save_calibration_result(temp_file) + + # Verify file exists and has content + with open(temp_file, "r") as f: + content = f.read() + self.assertNotEmpty(content) + + def test_get_signature_list(self): + signatures = self._interpreter.get_signature_list() + self.assertNotEmpty(signatures) + self.assertIn("serving_default", signatures) + + def test_runner_details(self): + runner = self._interpreter.get_signature_runner() + input_details = runner.get_input_details() + output_details = runner.get_output_details() + + self.assertNotEmpty(input_details) + self.assertNotEmpty(output_details) + self.assertIn("input_1", input_details) + + def test_output_match_original_interpreter(self): + # Run calibration interpreter + calib_runner = self._interpreter.get_signature_runner() + input_data = np.random.rand(1, 8).astype(np.float32) + calib_output = calib_runner(input_1=input_data) + + # Run original interpreter + original_interpreter = tfl_interpreter_utils.create_tfl_interpreter( + self._test_model_path + ) + original_runner = original_interpreter.get_signature_runner() + original_output = original_runner(input_1=input_data) + + # Compare + self.assertEqual(calib_output.keys(), original_output.keys()) + for key in calib_output: + np.testing.assert_array_equal(calib_output[key], original_output[key]) + + if __name__ == "__main__": absltest.main() diff --git a/ai_edge_quantizer/model_modifier.py b/ai_edge_quantizer/model_modifier.py index 41cd11c8..f492cd99 100644 --- a/ai_edge_quantizer/model_modifier.py +++ b/ai_edge_quantizer/model_modifier.py @@ -32,6 +32,7 @@ _DEQUANT_SUFFIX = "_dequant" +_QUANT_SUFFIX = "_quantized" class ModelModifier: @@ -119,26 +120,41 @@ def modify_model( serialized_quantized_model = serialize_fun(quantized_model) # Update signature defs if dequant is inserted before output. - if self._has_dequant_before_output(instructions): - quantized_model = self._update_signature_defs_for_dequant_output( - quantized_model, serialized_quantized_model + if self._has_transform_before_output( + instructions, qtyping.QuantTransformation.ADD_DEQUANTIZE + ): + quantized_model = self._update_signature_defs( + quantized_model, serialized_quantized_model, _DEQUANT_SUFFIX + ) + serialized_quantized_model = serialize_fun(quantized_model) + + # Update signature defs if quant is inserted before output. + if self._has_transform_before_output( + instructions, qtyping.QuantTransformation.ADD_QUANTIZE + ): + quantized_model = self._update_signature_defs( + quantized_model, serialized_quantized_model, _QUANT_SUFFIX ) serialized_quantized_model = serialize_fun(quantized_model) return serialized_quantized_model - def _update_signature_defs_for_dequant_output( - self, model: schema_py_generated.ModelT, serialized_model: bytearray - ): + def _update_signature_defs( + self, + model: schema_py_generated.ModelT, + serialized_model: bytearray, + suffix: str, + ) -> schema_py_generated.ModelT: """Updates the signature definitions in the model. - This function is called when a dequantize operation is inserted before - an output tensor. It updates the tensor index in the signature - definitions to point to the newly inserted dequantize output tensor. + This function is called when a transformation (quantize or dequantize) + is inserted before an output tensor. It updates the tensor index in the + signature definitions to point to the newly inserted output tensor. Args: model: The TFlite ModelT object. serialized_model: The serialized bytearray of the TFlite model. + suffix: The suffix to append to the tensor name. Returns: The updated TFlite ModelT object. @@ -164,7 +180,7 @@ def _update_signature_defs_for_dequant_output( logging.info("\tOutput tensor = `%s`", tensor_name) for signature_name, tensor_details in output_details.items(): - if tensor_details["name"] + _DEQUANT_SUFFIX == tensor_name: + if tensor_details["name"] + suffix == tensor_name: logging.info( "\t\tfound tensor mapping: `%s`->`%s` for signature name: `%s`", tensor_details["name"], @@ -184,18 +200,19 @@ def _update_signature_defs_for_dequant_output( return model - def _has_dequant_before_output( - self, instructions: dict[str, qtyping.TensorTransformationInsts] + def _has_transform_before_output( + self, + instructions: dict[str, qtyping.TensorTransformationInsts], + transformation: qtyping.QuantTransformation, ) -> bool: - """Check if the model has dequant insert to output.""" + """Check if the model has transformation insert to output.""" for tensor_name, tensor_trans_insts in instructions.items(): for instr in tensor_trans_insts.instructions: - if ( - qtyping.QuantTransformation.ADD_DEQUANTIZE == instr.transformation - and instr.consumers == [-1] - ): + if transformation == instr.transformation and instr.consumers == [-1]: logging.info( - "Found dequant insert to output for tensor: %s", tensor_name + "Found %s insert to output for tensor: %s", + transformation, + tensor_name, ) return True return False diff --git a/ai_edge_quantizer/model_modifier_test.py b/ai_edge_quantizer/model_modifier_test.py index 483784fd..1853a375 100644 --- a/ai_edge_quantizer/model_modifier_test.py +++ b/ai_edge_quantizer/model_modifier_test.py @@ -125,7 +125,7 @@ def test_modify_model_peak_memory_usage_in_acceptable_range(self): loosen_mem_use_factor = 4.5 self.assertLess(mem_peak / len(self._model_content), loosen_mem_use_factor) - def test_has_dequant_before_output_true(self): + def test_has_transform_before_output_true_dequant(self): instructions = { 'tensor1': qtyping.TensorTransformationInsts( 'tensor1', @@ -141,10 +141,12 @@ def test_has_dequant_before_output_true(self): ) } self.assertTrue( - self._model_modifier._has_dequant_before_output(instructions) + self._model_modifier._has_transform_before_output( + instructions, qtyping.QuantTransformation.ADD_DEQUANTIZE + ) ) - def test_has_dequant_before_output_false(self): + def test_has_transform_before_output_false_dequant(self): instructions = { 'tensor1': qtyping.TensorTransformationInsts( 'tensor1', @@ -160,7 +162,51 @@ def test_has_dequant_before_output_false(self): ) } self.assertFalse( - self._model_modifier._has_dequant_before_output(instructions) + self._model_modifier._has_transform_before_output( + instructions, qtyping.QuantTransformation.ADD_DEQUANTIZE + ) + ) + + def test_has_transform_before_output_true_quant(self): + instructions = { + 'tensor1': qtyping.TensorTransformationInsts( + 'tensor1', + 0, + instructions=[ + qtyping.TransformationInst( + transformation=qtyping.QuantTransformation.ADD_QUANTIZE, + tensor_id=0, + producer=0, + consumers=[-1], + ) + ], + ) + } + self.assertTrue( + self._model_modifier._has_transform_before_output( + instructions, qtyping.QuantTransformation.ADD_QUANTIZE + ) + ) + + def test_has_transform_before_output_false_quant(self): + instructions = { + 'tensor1': qtyping.TensorTransformationInsts( + 'tensor1', + 0, + instructions=[ + qtyping.TransformationInst( + transformation=qtyping.QuantTransformation.ADD_QUANTIZE, + tensor_id=0, + producer=0, + consumers=[1], + ) + ], + ) + } + self.assertFalse( + self._model_modifier._has_transform_before_output( + instructions, qtyping.QuantTransformation.ADD_QUANTIZE + ) ) def test_pad_bytearray(self): @@ -190,17 +236,25 @@ def setUp(self): ) self._model_modifier = model_modifier.ModelModifier(self._model_content) - def test_update_signature_defs_for_dequant_output_succeeds(self): + def test_update_signature_defs_succeeds_dequant(self): # This is a simplified test that only checks if the function runs without # crashing and returns a model. A more thorough test with a model # with a known signature was added in `quantizer_test`. model_bytearray = flatbuffer_utils.read_model_from_bytearray( self._model_content ) - updated_model = ( - self._model_modifier._update_signature_defs_for_dequant_output( - model_bytearray, bytearray(self._model_content) - ) + updated_model = self._model_modifier._update_signature_defs( + model_bytearray, bytearray(self._model_content), '_dequant' + ) + self.assertIsNotNone(updated_model) + + def test_update_signature_defs_succeeds_quant(self): + # This checks if the function runs without crashing and returns a model. + model_bytearray = flatbuffer_utils.read_model_from_bytearray( + self._model_content + ) + updated_model = self._model_modifier._update_signature_defs( + model_bytearray, bytearray(self._model_content), '_quantized' ) self.assertIsNotNone(updated_model) diff --git a/ai_edge_quantizer/utils/calibration_utils.py b/ai_edge_quantizer/utils/calibration_utils.py index ab4ef062..d6430acd 100644 --- a/ai_edge_quantizer/utils/calibration_utils.py +++ b/ai_edge_quantizer/utils/calibration_utils.py @@ -16,10 +16,12 @@ """Utilities for model calibration.""" import copy +import json from typing import Any, Union import numpy as np +import os from ai_edge_litert.tools import flatbuffer_utils from ai_edge_quantizer import qtyping from ai_edge_quantizer.algorithms.utils import common_utils @@ -27,7 +29,6 @@ from ai_edge_quantizer.utils import tfl_flatbuffer_utils from ai_edge_quantizer.utils import tfl_interpreter_utils - _SignatureInput = dict[str, Any] _OpQuantConstraint = common_utils.OpQuantConstraint _SignatureData = dict[ @@ -35,6 +36,19 @@ ] # signature_key -> list of signature_names. +class NumpyEncoder(json.JSONEncoder): + """JSON Encoder for Numpy types.""" + + def default(self, o): + if isinstance(o, np.integer): + return int(o) + elif isinstance(o, np.floating): + return float(o) + elif isinstance(o, np.ndarray): + return o.tolist() + return super().default(o) + + def _update_moving_average( smoothing_factor: Union[np.ndarray, float], w: np.ndarray, @@ -101,6 +115,27 @@ def min_max_update(qsv: qtyping.QSV, new_qsv: qtyping.QSV) -> qtyping.QSV: return updated_qsv +def load_calibration_results(file_path: str) -> dict[str, qtyping.QSV]: + """Loads calibration results from a file. + + Args: + file_path: Path to the calibration results file. + + Returns: + A dictionary of tensor name to QSV. + """ + with open(file_path) as json_file: + model_qsvs = json.load(json_file) + + # Convert lists back to numpy arrays + for _, qsv in model_qsvs.items(): + if "min" in qsv: + qsv["min"] = np.array(qsv["min"]) + if "max" in qsv: + qsv["max"] = np.array(qsv["max"]) + return model_qsvs + + def _find_overall_min_max( qsv: qtyping.QSV, tensor_names: list[str] ) -> tuple[np.ndarray, np.ndarray]: diff --git a/ai_edge_quantizer/utils/calibration_utils_test.py b/ai_edge_quantizer/utils/calibration_utils_test.py index 8a6f6890..4608ace6 100644 --- a/ai_edge_quantizer/utils/calibration_utils_test.py +++ b/ai_edge_quantizer/utils/calibration_utils_test.py @@ -128,6 +128,14 @@ def test_update_tensor_qsv_min_max(self, old_qsv, new_qsv, expected_qsv): self.assertEqual(updated_qsv["min"], expected_qsv["min"]) self.assertEqual(updated_qsv["max"], expected_qsv["max"]) + def test_load_calibration_results(self): + temp_file = self.create_tempfile() + temp_file.write_text('{"tensor1": {"min": [-1.0], "max": [1.0]}}') + results = calibration_utils.load_calibration_results(temp_file.full_path) + self.assertIn("tensor1", results) + self.assertTrue(np.array_equal(results["tensor1"]["min"], [-1.0])) + self.assertTrue(np.array_equal(results["tensor1"]["max"], [1.0])) + def test_calibration_utils_init_fails(self): model_path = "non_existent_model.tflite" with self.assertRaisesWithPredicateMatch(