forked from NVIDIA/Model-Optimizer
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path__main__.py
More file actions
369 lines (350 loc) · 13.5 KB
/
__main__.py
File metadata and controls
369 lines (350 loc) · 13.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Command-line entrypoint for ONNX PTQ."""
import argparse
import os
import numpy as np
from modelopt.onnx.quantization.quantize import quantize
# Public API of this module: only the command-line entrypoint is exported.
__all__ = ["main"]
def validate_file_size(file_path: str, max_size_bytes: int) -> None:
    """Ensure ``file_path`` exists and is no larger than ``max_size_bytes``.

    Args:
        file_path: Path of the file whose on-disk size is checked.
        max_size_bytes: Inclusive upper bound on the file size, in bytes.

    Raises:
        FileNotFoundError: If ``file_path`` does not exist.
        ValueError: If the file is strictly larger than ``max_size_bytes``.
    """
    if os.path.exists(file_path):
        size_in_bytes = os.path.getsize(file_path)
    else:
        raise FileNotFoundError(f"File not found: {file_path}")

    if size_in_bytes > max_size_bytes:
        bytes_per_gb = 1024 * 1024 * 1024
        # Sizes are reported in GB (2 decimals) purely for the error message.
        max_size_gb = max_size_bytes / bytes_per_gb
        actual_size_gb = size_in_bytes / bytes_per_gb
        message = (
            f"File size validation failed: {file_path} ({actual_size_gb:.2f}GB) exceeds "
            f"maximum allowed size of {max_size_gb:.2f}GB. This limit helps prevent potential "
            f"denial-of-service attacks."
        )
        raise ValueError(message)
def _str_to_bool(value: str) -> bool:
    """Parse a command-line boolean (e.g. ``True``/``false``/``1``/``no``) into a ``bool``."""
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "1", "on"):
        return True
    if normalized in ("false", "f", "no", "n", "0", "off"):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}.")


def get_parser() -> argparse.ArgumentParser:
    """Get the argument parser for ONNX PTQ.

    Returns:
        An ``argparse.ArgumentParser`` whose options mirror the keyword arguments of
        ``modelopt.onnx.quantization.quantize.quantize``.
    """
    argparser = argparse.ArgumentParser("python -m modelopt.onnx.quantization")
    # Calibration data and a pre-computed calibration cache are alternative sources of
    # activation scales, so at most one of them may be given.
    group = argparser.add_mutually_exclusive_group(required=False)
    argparser.add_argument(
        "--onnx_path", required=True, type=str, help="Input onnx model without Q/DQ nodes."
    )
    argparser.add_argument(
        "--quantize_mode",
        type=str,
        choices=["fp8", "int8", "int4"],
        default="int8",
        help="Quantization mode for the given ONNX model.",
    )
    argparser.add_argument(
        "--calibration_method",
        type=str,
        choices=["max", "entropy", "awq_clip", "rtn_dq"],
        help=(
            "Calibration method choices for int8/fp8: {entropy (default), max}, "
            "int4: {awq_clip (default), rtn_dq}."
        ),
    )
    group.add_argument(
        "--calibration_data_path",
        type=str,
        help="Calibration data in npz/npy format. If None, random data for calibration will be used.",
    )
    # Not part of the mutually exclusive group: this flag only makes sense *together with*
    # --calibration_data_path, which the group would otherwise forbid.
    argparser.add_argument(
        "--trust_calibration_data",
        action="store_true",
        help=(
            "If True, trust the calibration data and allow pickle deserialization. "
            "Only relevant together with --calibration_data_path."
        ),
    )
    group.add_argument(
        "--calibration_cache_path",
        type=str,
        help="Pre-calculated activation tensor scaling factors aka calibration cache path.",
    )
    argparser.add_argument(
        "--calibration_shapes",
        type=str,
        required=False,
        help=(
            "Optional model input shapes for calibration. "
            "Users should provide the shapes specifically if the model has non-batch dynamic dimensions. "
            "Example input shapes spec: input0:1x3x256x256,input1:1x3x128x128"
        ),
    )
    argparser.add_argument(
        "--calibration_eps",
        type=str,
        default=["cpu", "cuda:0", "trt"],
        nargs="+",
        help=(
            "Priority order for the execution providers (EP) to calibrate the model. "
            "Any subset of ['trt', 'cuda:x', 'dml:x', 'cpu'], where 'x' is the device id. "
            "If a custom op is detected in the model, 'trt' will automatically be added to the EP list."
        ),
    )
    argparser.add_argument(
        "--override_shapes",
        type=str,
        required=False,
        help=(
            "Override model input shapes with static shapes. "
            "Example input shapes spec: input0:1x3x256x256,input1:1x3x128x128"
        ),
    )
    argparser.add_argument(
        "--op_types_to_quantize",
        type=str,
        nargs="+",
        help="A space-separated list of node types to quantize.",
    )
    argparser.add_argument(
        "--op_types_to_exclude",
        type=str,
        nargs="+",
        help="A space-separated list of node types to exclude from quantization.",
    )
    argparser.add_argument(
        "--op_types_to_exclude_fp16",
        type=str,
        nargs="+",
        help=(
            "A space-separated list of node types to exclude from FP16/BF16 conversion. "
            "Relevant when --high_precision_dtype is 'fp16' or 'bf16'."
        ),
    )
    argparser.add_argument(
        "--nodes_to_quantize",
        type=str,
        nargs="+",
        help="A space-separated list of node names to quantize. Regular expressions are supported.",
    )
    argparser.add_argument(
        "--nodes_to_exclude",
        type=str,
        nargs="+",
        help="A space-separated list of node names to exclude from quantization. Regular expressions are supported.",
    )
    argparser.add_argument(
        "--use_external_data_format",
        action="store_true",
        help=(
            "If True or model size is larger than 2GB, "
            "<MODEL_NAME>.onnx_data will be used to write weights and constants."
        ),
    )
    argparser.add_argument(
        "--keep_intermediate_files",
        action="store_true",
        help=(
            "If True, keep the files generated during the ONNX models' conversion/calibration. "
            "Otherwise, only the converted ONNX file is kept for the user."
        ),
    )
    argparser.add_argument(
        "--output_path",
        type=str,
        help=(
            "Output filename to save the converted ONNX model. If None, save it in the same dir as "
            "the original ONNX model with an appropriate suffix."
        ),
    )
    argparser.add_argument(
        "--log_level",
        type=str,
        choices=["DEBUG", "INFO", "WARNING", "ERROR", "debug", "info", "warning", "error"],
        default="INFO",
        help="Set the logging level for the quantization process.",
    )
    argparser.add_argument(
        "--log_file",
        type=str,
        default=None,
        help="Path to the log file for the quantization process.",
    )
    argparser.add_argument(
        "--trt_plugins",
        type=str,
        default=None,
        nargs="+",
        help=(
            "A space-separated list with the custom TensorRT plugin library paths in .so format (compiled shared "
            "library). If this is not None, the TensorrtExecutionProvider is invoked, so make sure that the TensorRT "
            "libraries are in the PATH or LD_LIBRARY_PATH variables."
        ),
    )
    argparser.add_argument(
        "--trt_plugins_precision",
        type=str,
        default=None,
        nargs="+",
        help=(
            "A space-separated list indicating the precision for each custom op. "
            "Each item should have the format <op_type>:<precision> (all inputs and outputs have the same precision) "
            "or <op_type>:[<inp1_precision>,<inp2_precision>,...]:[<out1_precision>,<out2_precision>,...] "
            "(inputs and outputs can have different precisions), where precision can be fp32 (default), "
            "fp16, int8, or fp8. Note that int8/fp8 should be the same as the quantization mode. "
            "For example: op_type_1:fp16 op_type_2:[int8,fp32]:[int8]."
        ),
    )
    argparser.add_argument(
        "--high_precision_dtype",
        type=str,
        default="fp16",
        choices=["fp32", "fp16", "bf16"],
        help=(
            "High precision data type of the output model. If the input model is of dtype fp32, "
            "it will be converted to fp16 dtype by default."
        ),
    )
    argparser.add_argument(
        "--mha_accumulation_dtype",
        type=str,
        default="fp16",
        help=(
            "Accumulation dtype of MHA. This flag will only take effect when mha_accumulation_dtype == 'fp32' "
            "and quantize_mode == 'fp8'. One of ['fp32', 'fp16']"
        ),
    )
    argparser.add_argument(
        "--disable_mha_qdq",
        action="store_true",
        help="If True, Q/DQ will not be added to MatMuls in MHA pattern.",
    )
    argparser.add_argument(
        "--dq_only",
        action="store_true",
        help=(
            "If True, FP32/FP16 weights will be converted to INT8/FP8 weights. Q nodes will get removed from the "
            "weights and have only DQ nodes with those converted INT8/FP8 weights in the output model."
        ),
    )
    # `type=bool` would treat any non-empty string (including "False") as True, so parse the
    # value explicitly. A bare `--use_zero_point` (no value) also enables it.
    argparser.add_argument(
        "--use_zero_point",
        type=_str_to_bool,
        nargs="?",
        const=True,
        default=False,
        help=(
            "If True, zero-point based quantization will be used - currently, applicable for awq_lite algorithm."
        ),
    )
    argparser.add_argument(
        "--passes",
        type=str,
        choices=["concat_elimination"],
        default=["concat_elimination"],
        nargs="+",
        help=(
            "A space-separated list of optimization passes name, if set, appropriate pre/post-processing passes will "
            "be invoked."
        ),
    )
    argparser.add_argument(
        "--simplify",
        action="store_true",
        help="If True, the given ONNX model will be simplified before quantization is performed.",
    )
    argparser.add_argument(
        "--calibrate_per_node",
        action="store_true",
        help=(
            "If set, performs calibration per node instead of running inference over the entire network. "
            "Useful for reducing memory consumption during large model inference."
        ),
    )
    argparser.add_argument(
        "--direct_io_types",
        action="store_true",
        help=(
            "If True, the I/O types in the quantized ONNX model will be modified to be lower precision whenever "
            "possible. Else, they will match the I/O types in the given ONNX model. "
            "The currently supported precisions are {fp16, int8, fp8}."
        ),
    )
    argparser.add_argument(
        "--opset",
        type=int,
        help=(
            "Target ONNX opset version for the quantized model. If not specified, uses default minimum opset "
            "(19 for fp16 scales support, 21 for int4, 23 for nvfp4). The opset may be automatically increased "
            "if certain operations require a higher version."
        ),
    )
    return argparser
def _load_calibration_data(calibration_data_path: str, trust_calibration_data: bool):
    """Load calibration data from a ``.npy``/``.npz`` file, refusing pickles unless trusted.

    Args:
        calibration_data_path: Path to the calibration data file.
        trust_calibration_data: Whether pickle deserialization may be used while loading.

    Returns:
        A ``dict`` mapping array names to arrays when the file is an npz archive, otherwise
        whatever ``np.load`` returned (typically a single ``np.ndarray``).

    Raises:
        ValueError: If the file needs pickle deserialization but ``trust_calibration_data`` is
            False, or if ``np.load`` otherwise rejects the file.
    """
    # Security: pickle deserialization can execute arbitrary code (RCE), so it is only
    # enabled for calibration files the user explicitly marked as trusted.
    try:
        loaded = np.load(calibration_data_path, allow_pickle=trust_calibration_data)
        # Decide by the returned object rather than the file extension: np.load detects
        # npz archives from their content.
        if isinstance(loaded, np.lib.npyio.NpzFile):
            # NpzFile keeps the archive open; materialize every array into a plain dict and
            # close the file so the handle is not leaked.
            with loaded:
                return {key: loaded[key] for key in loaded.files}
        return loaded
    except ValueError as e:
        if "allow_pickle" in str(e) and not trust_calibration_data:
            raise ValueError(
                "Calibration data file contains pickled objects which pose a security risk. "
                "For trusted sources, you may enable pickle deserialization by setting the "
                "--trust_calibration_data flag."
            ) from e
        raise


def main():
    """Command-line entrypoint for ONNX PTQ.

    Parses the command-line arguments and rejects models larger than 2GB unless
    ``--use_external_data_format`` is set. It then loads the optional calibration data and
    forwards every option to ``quantize``.
    """
    args = get_parser().parse_args()

    # Security: Validate onnx model size is under 2GB unless external data format is requested.
    if not args.use_external_data_format:
        try:
            validate_file_size(args.onnx_path, 2 * (1024**3))
        except ValueError as e:
            raise ValueError(
                "Onnx model size larger than 2GB. Please set --use_external_data_format flag to bypass this validation."
            ) from e

    calibration_data = None
    if args.calibration_data_path:
        calibration_data = _load_calibration_data(
            args.calibration_data_path, args.trust_calibration_data
        )

    quantize(
        args.onnx_path,
        quantize_mode=args.quantize_mode,
        calibration_data=calibration_data,
        calibration_method=args.calibration_method,
        calibration_cache_path=args.calibration_cache_path,
        calibration_shapes=args.calibration_shapes,
        calibration_eps=args.calibration_eps,
        override_shapes=args.override_shapes,
        op_types_to_quantize=args.op_types_to_quantize,
        op_types_to_exclude=args.op_types_to_exclude,
        op_types_to_exclude_fp16=args.op_types_to_exclude_fp16,
        nodes_to_quantize=args.nodes_to_quantize,
        nodes_to_exclude=args.nodes_to_exclude,
        use_external_data_format=args.use_external_data_format,
        keep_intermediate_files=args.keep_intermediate_files,
        output_path=args.output_path,
        log_level=args.log_level,
        log_file=args.log_file,
        trt_plugins=args.trt_plugins,
        trt_plugins_precision=args.trt_plugins_precision,
        high_precision_dtype=args.high_precision_dtype,
        mha_accumulation_dtype=args.mha_accumulation_dtype,
        disable_mha_qdq=args.disable_mha_qdq,
        dq_only=args.dq_only,
        use_zero_point=args.use_zero_point,
        passes=args.passes,
        simplify=args.simplify,
        calibrate_per_node=args.calibrate_per_node,
        direct_io_types=args.direct_io_types,
        opset=args.opset,
    )
# Run the CLI when this module is executed as a script, e.g.
# ``python -m modelopt.onnx.quantization`` (the program name used by get_parser).
if __name__ == "__main__":
    main()