Skip to content

Commit 8f79948

Browse files
committed
Expose calibration_data_reader to the public interface to allow users to create their own iterator
Signed-off-by: dmoodie <dmoodie@nvidia.com>
1 parent f22f4f5 commit 8f79948

3 files changed

Lines changed: 64 additions & 10 deletions

File tree

CHANGELOG.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
NVIDIA Model Optimizer Changelog
22
================================
3+
0.44 (2026-04-xx)
4+
5+
**New Features**
6+
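- Added a ``calibration_data_reader`` argument to ``modelopt.onnx.quantization.quantize`` so users can supply their own ``CalibrationDataReader`` iterator in the ONNX quantization workflow.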
7+
38

49
0.44 (2026-05-xx)
510
^^^^^^^^^^^^^^^^^

modelopt/onnx/quantization/quantize.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
import onnx.onnx_cpp2py_export.checker as C
4444
import onnx_graphsurgeon as gs
4545
import onnxslim
46+
from onnxruntime.quantization.calibrate import CalibrationDataReader
4647

4748
from modelopt.onnx.logging_config import configure_logging, logger
4849
from modelopt.onnx.op_types import is_data_dependent_shape_op
@@ -305,6 +306,7 @@ def quantize(
305306
calibration_data: CalibrationDataType = None,
306307
calibration_method: str | None = None,
307308
calibration_cache_path: str | None = None,
309+
calibration_data_reader: CalibrationDataReader | None = None,
308310
calibration_shapes: str | None = None,
309311
calibration_eps: list[str] = ["cpu", "cuda:0", "trt"],
310312
override_shapes: str | None = None,
@@ -361,6 +363,8 @@ def quantize(
361363
and int4: {'awq_clip' (default), 'awq_lite', 'awq_full', 'rtn_dq'}.
362364
calibration_cache_path:
363365
Path to pre-calculated activation tensor ranges, also known as calibration cache.
366+
calibration_data_reader:
367+
Instance of a CalibrationDataReader used to provide calibration data. When given, it is used as-is and no reader is built from ``calibration_data`` (or from random data).
364368
calibration_shapes:
365369
Input shapes used for calibration process.
366370
It should be provided as a string representing the shape of each input tensors for one calibration step.
@@ -571,13 +575,14 @@ def quantize(
571575
)
572576
trt_plugins = update_trt_ep_support(calibration_eps, has_dds_op, has_custom_op, trt_plugins) # type: ignore[arg-type]
573577

574-
# Use random scales if calibration data is not supplied
575-
if calibration_data is None:
576-
calibration_data_reader = RandomDataProvider(onnx_path, calibration_shapes)
577-
else:
578-
calibration_data_reader = CalibrationDataProvider(
579-
onnx_path, calibration_data, calibration_shapes
580-
)
578+
if calibration_data_reader is None:
579+
# Use random scales if calibration data is not supplied
580+
if calibration_data is None:
581+
calibration_data_reader = RandomDataProvider(onnx_path, calibration_shapes)
582+
else:
583+
calibration_data_reader = CalibrationDataProvider(
584+
onnx_path, calibration_data, calibration_shapes
585+
)
581586

582587
nodes_to_quantize = nodes_to_quantize or []
583588
nodes_to_exclude = nodes_to_exclude or []

tests/unit/onnx/quantization/test_quantize_int8.py

Lines changed: 47 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import pytest
2121
import torch
2222
from _test_utils.onnx.lib_test_models import SimpleMLP, export_as_onnx
23+
from onnxruntime.quantization.calibrate import CalibrationDataReader
2324

2425
import modelopt.onnx.quantization as moq
2526

@@ -34,14 +35,15 @@ def assert_nodes_are_quantized(nodes):
3435
return True
3536

3637

37-
@pytest.mark.parametrize("high_precision_dtype", ["fp32", "fp16", "bf16"])
38-
def test_int8(tmp_path, high_precision_dtype):
38+
def int8_test_helper(tmp_path, high_precision_dtype, **kwargs):
3939
model_torch = SimpleMLP()
4040
input_tensor = torch.randn(2, 16, 16)
4141

4242
onnx_path = os.path.join(tmp_path, "model.onnx")
4343
export_as_onnx(model_torch, input_tensor, onnx_filename=onnx_path)
44-
moq.quantize(onnx_path, quantize_mode="int8", high_precision_dtype=high_precision_dtype)
44+
moq.quantize(
45+
onnx_path, quantize_mode="int8", high_precision_dtype=high_precision_dtype, **kwargs
46+
)
4547

4648
# Output model should be produced in the same tmp_path
4749
output_onnx_path = onnx_path.replace(".onnx", ".quant.onnx")
@@ -55,3 +57,45 @@ def test_int8(tmp_path, high_precision_dtype):
5557
# Check that all MatMul nodes are quantized
5658
mm_nodes = [n for n in graph.nodes if n.op == "MatMul"]
5759
assert assert_nodes_are_quantized(mm_nodes)
60+
61+
62+
@pytest.mark.parametrize("high_precision_dtype", ["fp32", "fp16", "bf16"])
63+
def test_int8(tmp_path, high_precision_dtype):
64+
int8_test_helper(tmp_path, high_precision_dtype)
65+
66+
67+
@pytest.mark.parametrize("high_precision_dtype", ["fp32", "fp16", "bf16"])
68+
def test_int8_with_calibration_reader(tmp_path, high_precision_dtype):
69+
input_tensor = torch.randn(2, 16, 16)
70+
71+
# Calibration data comes from a custom data reader, which enables iterator-based reading
72+
class ExampleCalibrationDataReader(CalibrationDataReader):
73+
def __init__(self, input_data):
74+
self.data_list = [{"input": input_data.numpy()}]
75+
self.iter = iter(self.data_list)
76+
self.get_first_calls = 0
77+
self.get_next_calls = 0
78+
79+
def get_next(self):
80+
self.get_next_calls += 1
81+
return next(self.iter, None)
82+
83+
def get_first(self):
84+
self.get_first_calls += 1
85+
return self.data_list[0]
86+
87+
def rewind(self):
88+
self.iter = iter(self.data_list)
89+
90+
calibration_reader = ExampleCalibrationDataReader(input_tensor)
91+
int8_test_helper(tmp_path, high_precision_dtype, calibration_data_reader=calibration_reader)
92+
assert calibration_reader.get_first_calls > 0 or calibration_reader.get_next_calls > 0
93+
94+
95+
@pytest.mark.parametrize("high_precision_dtype", ["fp32", "fp16", "bf16"])
96+
def test_int8_with_calibration_data(tmp_path, high_precision_dtype):
97+
input_tensor = torch.randn(2, 16, 16)
98+
99+
# test pre-allocated calibration data pathway
100+
calibration_data = {"input": input_tensor.numpy()}
101+
int8_test_helper(tmp_path, high_precision_dtype, calibration_data=calibration_data)

0 commit comments

Comments
 (0)