Skip to content

Commit 4d7d297

Browse files
committed
Add support for single npz file with multiple samples
Signed-off-by: gcunhase <4861122+gcunhase@users.noreply.github.com>
1 parent 4f4558a commit 4d7d297

2 files changed

Lines changed: 8 additions & 3 deletions

File tree

modelopt/onnx/autocast/referencerunner.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import onnx
3131

3232
from modelopt.onnx.autocast.logging_config import configure_logging, logger
33+
from modelopt.onnx.quantization.calib_utils import CalibrationDataProvider
3334
from modelopt.onnx.quantization.ort_utils import _prepare_ep_list
3435

3536
configure_logging()
@@ -70,7 +71,11 @@ def _load_inputs_from_json(self, input_data_path):
7071

7172
def _load_inputs_from_npz(self, input_data_path):
7273
"""Load inputs from NPZ format."""
73-
return [np.load(input_data_path)]
74+
calib_data = np.load(input_data_path)
75+
76+
# Wrap data into a CalibDataProvider to support a single NPZ file containing data from multiple batches
77+
data_loader = {key: calib_data[key] for key in calib_data.files}
78+
return CalibrationDataProvider(self.model, data_loader).calibration_data_list
7479

7580
def _validate_inputs(self, data_loader):
7681
"""Validate that input names and shapes match the model."""

modelopt/onnx/quantization/calib_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ class CalibrationDataProvider(CalibrationDataReader):
3838

3939
def __init__(
4040
self,
41-
onnx_path: str,
41+
onnx_path: str | onnx.ModelProto,
4242
calibration_data: CalibrationDataType,
4343
calibration_shapes: str | None = None,
4444
):
@@ -58,7 +58,7 @@ def __init__(
5858
logger.info("Setting up CalibrationDataProvider for calibration")
5959
# Tensor data is not required to generate the calibration data
6060
# So even if the model has external data, we don't need to load them here
61-
onnx_model = onnx.load(onnx_path)
61+
onnx_model = onnx.load(onnx_path) if isinstance(onnx_path, str) else onnx_path
6262
input_names = get_input_names(onnx_model)
6363
input_shapes = {} if calibration_shapes is None else parse_shapes_spec(calibration_shapes)
6464
inferred_input_shapes = get_input_shapes(onnx_model)

0 commit comments

Comments
 (0)