|
16 | 16 | """Command-line entrypoint for ONNX PTQ.""" |
17 | 17 |
|
18 | 18 | import argparse |
| 19 | +import os |
19 | 20 |
|
20 | 21 | import numpy as np |
21 | 22 |
|
|
24 | 25 | __all__ = ["main"] |
25 | 26 |
|
26 | 27 |
|
def validate_file_size(file_path: str, max_size_bytes: int) -> None:
    """Validate that a regular file exists and does not exceed the maximum allowed size.

    Args:
        file_path: Path to the file to validate
        max_size_bytes: Maximum allowed file size in bytes

    Raises:
        FileNotFoundError: If the path does not exist or is not a regular file
            (for example, a directory)
        ValueError: If the file exceeds the maximum allowed size
    """
    # os.path.exists() is also true for directories, whose reported size is just the
    # directory entry and would slip under the limit; require a regular file instead.
    if not os.path.isfile(file_path):
        raise FileNotFoundError(f"File not found: {file_path}")

    file_size = os.path.getsize(file_path)
    if file_size <= max_size_bytes:
        return

    # Sizes are reported in GB (1024**3 bytes) to keep the message readable.
    bytes_per_gb = 1024 * 1024 * 1024
    max_size_gb = max_size_bytes / bytes_per_gb
    actual_size_gb = file_size / bytes_per_gb
    raise ValueError(
        f"File size validation failed: {file_path} ({actual_size_gb:.2f}GB) exceeds "
        f"maximum allowed size of {max_size_gb:.2f}GB. This limit helps prevent potential "
        f"denial-of-service attacks."
    )
| 51 | + |
| 52 | + |
27 | 53 | def get_parser() -> argparse.ArgumentParser: |
28 | 54 | """Get the argument parser for ONNX PTQ.""" |
29 | 55 | argparser = argparse.ArgumentParser("python -m modelopt.onnx.quantization") |
@@ -52,6 +78,11 @@ def get_parser() -> argparse.ArgumentParser: |
52 | 78 | type=str, |
53 | 79 | help="Calibration data in npz/npy format. If None, random data for calibration will be used.", |
54 | 80 | ) |
| 81 | + group.add_argument( |
| 82 | + "--trust_calibration_data", |
| 83 | + action="store_true", |
| 84 | + help="If True, trust the calibration data and allow pickle deserialization.", |
| 85 | + ) |
55 | 86 | group.add_argument( |
56 | 87 | "--calibration_cache_path", |
57 | 88 | type=str, |
@@ -261,12 +292,35 @@ def get_parser() -> argparse.ArgumentParser: |
261 | 292 | def main(): |
262 | 293 | """Command-line entrypoint for ONNX PTQ.""" |
263 | 294 | args = get_parser().parse_args() |
| 295 | + |
| 296 | + # Security: Validate onnx model size is under 2GB by default |
| 297 | + if not args.use_external_data_format: |
| 298 | + try: |
| 299 | + validate_file_size(args.onnx_path, 2 * (1024**3)) |
| 300 | + except ValueError as e: |
| 301 | + raise ValueError( |
| 302 | + "Onnx model size larger than 2GB. Please set --use_external_data_format flag to bypass this validation." |
| 303 | + ) from e |
| 304 | + |
264 | 305 | calibration_data = None |
265 | 306 | if args.calibration_data_path: |
266 | | - calibration_data = np.load(args.calibration_data_path, allow_pickle=True) |
267 | | - if args.calibration_data_path.endswith(".npz"): |
268 | | - # Convert the NpzFile object to a Python dictionary |
269 | | - calibration_data = {key: calibration_data[key] for key in calibration_data.files} |
| 307 | + # Security: Disable pickle deserialization for untrusted sources to prevent RCE attacks |
| 308 | + try: |
| 309 | + calibration_data = np.load( |
| 310 | + args.calibration_data_path, allow_pickle=args.trust_calibration_data |
| 311 | + ) |
| 312 | + if args.calibration_data_path.endswith(".npz"): |
| 313 | + # Convert the NpzFile object to a Python dictionary |
| 314 | + calibration_data = {key: calibration_data[key] for key in calibration_data.files} |
| 315 | + except ValueError as e: |
| 316 | + if "allow_pickle" in str(e) and not args.trust_calibration_data: |
| 317 | + raise ValueError( |
| 318 | + "Calibration data file contains pickled objects which pose a security risk. " |
| 319 | + "For trusted sources, you may enable pickle deserialization by setting the " |
| 320 | + "--trust_calibration_data flag." |
| 321 | + ) from e |
| 322 | + else: |
| 323 | + raise |
270 | 324 |
|
271 | 325 | quantize( |
272 | 326 | args.onnx_path, |
|
0 commit comments