2020import pytest
2121import torch
2222from _test_utils .onnx .lib_test_models import SimpleMLP , export_as_onnx
23+ from onnxruntime .quantization .calibrate import CalibrationDataReader
2324
2425import modelopt .onnx .quantization as moq
2526
@@ -34,14 +35,15 @@ def assert_nodes_are_quantized(nodes):
3435 return True
3536
3637
37- @pytest .mark .parametrize ("high_precision_dtype" , ["fp32" , "fp16" , "bf16" ])
38- def test_int8 (tmp_path , high_precision_dtype ):
38+ def int8_test_helper (tmp_path , high_precision_dtype , ** kwargs ):
3939 model_torch = SimpleMLP ()
4040 input_tensor = torch .randn (2 , 16 , 16 )
4141
4242 onnx_path = os .path .join (tmp_path , "model.onnx" )
4343 export_as_onnx (model_torch , input_tensor , onnx_filename = onnx_path )
44- moq .quantize (onnx_path , quantize_mode = "int8" , high_precision_dtype = high_precision_dtype )
44+ moq .quantize (
45+ onnx_path , quantize_mode = "int8" , high_precision_dtype = high_precision_dtype , ** kwargs
46+ )
4547
4648 # Output model should be produced in the same tmp_path
4749 output_onnx_path = onnx_path .replace (".onnx" , ".quant.onnx" )
@@ -55,3 +57,45 @@ def test_int8(tmp_path, high_precision_dtype):
5557 # Check that all MatMul nodes are quantized
5658 mm_nodes = [n for n in graph .nodes if n .op == "MatMul" ]
5759 assert assert_nodes_are_quantized (mm_nodes )
60+
61+
62+ @pytest .mark .parametrize ("high_precision_dtype" , ["fp32" , "fp16" , "bf16" ])
63+ def test_int8 (tmp_path , high_precision_dtype ):
64+ int8_test_helper (tmp_path , high_precision_dtype )
65+
66+
67+ @pytest .mark .parametrize ("high_precision_dtype" , ["fp32" , "fp16" , "bf16" ])
68+ def test_int8_with_calibration_reader (tmp_path , high_precision_dtype ):
69+ input_tensor = torch .randn (2 , 16 , 16 )
70+
71+ # Calibration data comes from a custom data reader, enabling iterator based reading functionality
72+ class ExampleCalibrationDataReader (CalibrationDataReader ):
73+ def __init__ (self , input_data ):
74+ self .data_list = [{"input" : input_data .numpy ()}]
75+ self .iter = iter (self .data_list )
76+ self .get_first_calls = 0
77+ self .get_next_calls = 0
78+
79+ def get_next (self ):
80+ self .get_first_calls += 1
81+ return next (self .iter , None )
82+
83+ def get_first (self ):
84+ self .get_next_calls += 1
85+ return self .data_list [0 ]
86+
87+ def rewind (self ):
88+ self .iter = iter (self .data_list )
89+
90+ calibration_reader = ExampleCalibrationDataReader (input_tensor )
91+ int8_test_helper (tmp_path , high_precision_dtype , calibration_data_reader = calibration_reader )
92+ assert calibration_reader .get_first_calls > 0 or calibration_reader .get_next_calls > 0
93+
94+
95+ @pytest .mark .parametrize ("high_precision_dtype" , ["fp32" , "fp16" , "bf16" ])
96+ def test_int8_with_calibration_data (tmp_path , high_precision_dtype ):
97+ input_tensor = torch .randn (2 , 16 , 16 )
98+
99+ # test pre-allocated calibration data pathway
100+ calibration_data = {"input" : input_tensor .numpy ()}
101+ int8_test_helper (tmp_path , high_precision_dtype , calibration_data = calibration_data )
0 commit comments