Skip to content

Commit 58a5f1e

Browse files
[1/2] Address security concerns in code (#626)
- [x] Address feedback on Threat and Vuln Analysis (TAVA) doc by ProdSec team - [x] Add note on safe usage of pickle deserialization of modelopt-generated state files **TODO: [Separate PR]** Replace pickle usage in `modelopt/torch/opt/plugins/megatron.py` - Needs fix on TransformerEngine first as we copy from https://github.com/NVIDIA/TransformerEngine/blob/3ff0b8d4/transformer_engine/pytorch/module/base.py#L863 <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Added `--trust_calibration_data` CLI flag for secure ONNX quantization with pickle data files. * **Improvements** * Enhanced security validation for generated quantization code. * Simplified data loading by removing pickle-based caching—data is now always loaded fresh. * Added security guidance throughout model state loading operations. * **Documentation** * Updated guides with security best practices for model state handling. <sub>✏️ Tip: You can customize this high-level summary in your review settings.</sub> <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
1 parent def1e32 commit 58a5f1e

15 files changed

Lines changed: 104 additions & 30 deletions

File tree

docs/source/guides/2_save_load.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ Here is the example workflow of restoring the ModelOpt-modified model architectu
129129
model = ...
130130
131131
# Restore the model architecture using the saved `modelopt_state`
132+
# Security NOTE: weights_only=False is used here on ModelOpt-generated state_dict, not on untrusted user input
132133
modelopt_state = torch.load("modelopt_state.pth", weights_only=False)
133134
model = mto.restore_from_modelopt_state(model, modelopt_state)
134135

examples/llm_qat/export.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ def get_model(
5151

5252
# Restore modelopt state for LoRA models. For QAT/QAD models from_pretrained call handles this
5353
if hasattr(model, "peft_config"):
54+
# Security NOTE: weights_only=False is used here on ModelOpt-generated state_dict, not on untrusted user input
5455
modelopt_state = torch.load(f"{ckpt_path}/modelopt_state_train.pth", weights_only=False)
5556
restore_from_modelopt_state(model, modelopt_state)
5657
print_rank_0("Restored modelopt state")

examples/llm_sparsity/weight_sparsity/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ python data_prep.py --save_path data
8484

8585
The following command demonstrates how to perform SAT on the Llama2-7B model on 8 GPUs.
8686
The model is finetuned on the [cnn_dailymail](https://huggingface.co/datasets/abisee/cnn_dailymail) dataset for 3 epochs.
87-
The input data is tokenized to a maximum length of 1024 tokens. The tokenized data is saved as a pickle file for faster data loading. The one-time process takes less than an hour to finish depending on the CPU. The resulting pickle file can be utilized for future training sessions.
87+
The input data is tokenized to a maximum length of 1024 tokens.
8888

8989
```sh
9090
bash launch_finetune.sh --model meta-llama/Llama-2-7b-hf \

examples/llm_sparsity/weight_sparsity/finetune.py

Lines changed: 10 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
import argparse
3333
import copy
3434
import os
35-
import pickle
3635
from collections.abc import Sequence
3736
from dataclasses import dataclass, field
3837

@@ -232,27 +231,17 @@ def __init__(
232231
):
233232
super().__init__()
234233

235-
pickle_name = f"dict_{split}_{tokenizer.model_max_length}.pickle"
236234
with training_args.main_process_first():
237-
if os.path.isfile(pickle_name):
238-
with open(pickle_name, "rb") as f:
239-
print_rank_0("Reuse pickled data")
240-
data_dict = pickle.load(f)
241-
else:
242-
print_rank_0("Loading data...")
243-
list_data_dict = utils.jload(data_path)
244-
245-
print_rank_0("Formatting inputs...")
246-
prompt_input = PROMPT_DICT["prompt_input"]
247-
sources = [prompt_input.format_map(example) for example in list_data_dict]
248-
targets = [
249-
f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict
250-
]
251-
252-
print_rank_0("Tokenizing inputs... This may take some time...")
253-
data_dict = preprocess(sources, targets, tokenizer)
254-
with open(pickle_name, "wb") as f:
255-
pickle.dump(data_dict, f, pickle.HIGHEST_PROTOCOL)
235+
print_rank_0("Loading data...")
236+
list_data_dict = utils.jload(data_path)
237+
238+
print_rank_0("Formatting inputs...")
239+
prompt_input = PROMPT_DICT["prompt_input"]
240+
sources = [prompt_input.format_map(example) for example in list_data_dict]
241+
targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict]
242+
243+
print_rank_0("Tokenizing inputs... This may take some time...")
244+
data_dict = preprocess(sources, targets, tokenizer)
256245

257246
self.input_ids = data_dict["input_ids"]
258247
self.labels = data_dict["labels"]

modelopt/onnx/quantization/__main__.py

Lines changed: 58 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
"""Command-line entrypoint for ONNX PTQ."""
1717

1818
import argparse
19+
import os
1920

2021
import numpy as np
2122

@@ -24,6 +25,31 @@
2425
__all__ = ["main"]
2526

2627

28+
def validate_file_size(file_path: str, max_size_bytes: int) -> None:
    """Validate that a regular file exists and does not exceed the maximum allowed size.

    Args:
        file_path: Path to the file to validate
        max_size_bytes: Maximum allowed file size in bytes

    Raises:
        FileNotFoundError: If the path does not exist or is not a regular file
        ValueError: If the file exceeds the maximum allowed size
    """
    # Use isfile rather than exists: a directory "exists" but has no meaningful
    # size, so it must not be allowed to pass the size validation.
    if not os.path.isfile(file_path):
        raise FileNotFoundError(f"File not found: {file_path}")

    file_size = os.path.getsize(file_path)
    if file_size > max_size_bytes:
        bytes_per_gb = 1024**3
        max_size_gb = max_size_bytes / bytes_per_gb
        actual_size_gb = file_size / bytes_per_gb
        raise ValueError(
            f"File size validation failed: {file_path} ({actual_size_gb:.2f}GB) exceeds "
            f"maximum allowed size of {max_size_gb:.2f}GB. This limit helps prevent potential "
            f"denial-of-service attacks."
        )
51+
52+
2753
def get_parser() -> argparse.ArgumentParser:
2854
"""Get the argument parser for ONNX PTQ."""
2955
argparser = argparse.ArgumentParser("python -m modelopt.onnx.quantization")
@@ -52,6 +78,11 @@ def get_parser() -> argparse.ArgumentParser:
5278
type=str,
5379
help="Calibration data in npz/npy format. If None, random data for calibration will be used.",
5480
)
81+
group.add_argument(
82+
"--trust_calibration_data",
83+
action="store_true",
84+
help="If True, trust the calibration data and allow pickle deserialization.",
85+
)
5586
group.add_argument(
5687
"--calibration_cache_path",
5788
type=str,
@@ -261,12 +292,35 @@ def get_parser() -> argparse.ArgumentParser:
261292
def main():
262293
"""Command-line entrypoint for ONNX PTQ."""
263294
args = get_parser().parse_args()
295+
296+
# Security: Validate onnx model size is under 2GB by default
297+
if not args.use_external_data_format:
298+
try:
299+
validate_file_size(args.onnx_path, 2 * (1024**3))
300+
except ValueError as e:
301+
raise ValueError(
302+
"Onnx model size larger than 2GB. Please set --use_external_data_format flag to bypass this validation."
303+
) from e
304+
264305
calibration_data = None
265306
if args.calibration_data_path:
266-
calibration_data = np.load(args.calibration_data_path, allow_pickle=True)
267-
if args.calibration_data_path.endswith(".npz"):
268-
# Convert the NpzFile object to a Python dictionary
269-
calibration_data = {key: calibration_data[key] for key in calibration_data.files}
307+
# Security: Disable pickle deserialization for untrusted sources to prevent RCE attacks
308+
try:
309+
calibration_data = np.load(
310+
args.calibration_data_path, allow_pickle=args.trust_calibration_data
311+
)
312+
if args.calibration_data_path.endswith(".npz"):
313+
# Convert the NpzFile object to a Python dictionary
314+
calibration_data = {key: calibration_data[key] for key in calibration_data.files}
315+
except ValueError as e:
316+
if "allow_pickle" in str(e) and not args.trust_calibration_data:
317+
raise ValueError(
318+
"Calibration data file contains pickled objects which pose a security risk. "
319+
"For trusted sources, you may enable pickle deserialization by setting the "
320+
"--trust_calibration_data flag."
321+
) from e
322+
else:
323+
raise
270324

271325
quantize(
272326
args.onnx_path,

modelopt/torch/export/distribute.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ def read_configs_and_weights_from_rank(
9191
raise ValueError("NFSWorkspace is not initialized!")
9292
state_path = self._get_state_path(target_rank)
9393
if state_path.exists():
94+
# Security NOTE: weights_only=False is used here on ModelOpt-generated ckpt, not on untrusted user input
9495
state = torch.load(state_path, map_location="cpu", weights_only=False)
9596
return state["config"], state["weight"]
9697
else:

modelopt/torch/opt/conversion.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -526,6 +526,7 @@ def restore_from_modelopt_state(model: ModelLike, modelopt_state: dict[str, Any]
526526
model = ... # Create the model-like object
527527
528528
# Restore the previously saved modelopt state followed by model weights
529+
# Security NOTE: weights_only=False is used here on ModelOpt-generated state_dict, not on untrusted user input
529530
mto.restore_from_modelopt_state(
530531
model, torch.load("modelopt_state.pt", weights_only=False)
531532
) # Restore modelopt state

modelopt/torch/opt/plugins/huggingface.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ def new_init_fn(self, *args, **kwargs):
7979
modelopt_state_path = _get_modelopt_state_path(model_path)
8080
_original__init__(self, *args, **kwargs)
8181
if os.path.isfile(modelopt_state_path):
82+
# Security NOTE: weights_only=False is used on ModelOpt-generated state_dict, not on untrusted user input
8283
modelopt_state = torch.load(modelopt_state_path, map_location="cpu", weights_only=False)
8384
with extra_context() if extra_context else nullcontext():
8485
restore_from_modelopt_state(self, modelopt_state)

modelopt/torch/opt/plugins/mcore_dist_checkpointing.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,7 @@ def restore_sharded_modelopt_state(
242242
return
243243

244244
# Loading the common modelopt_state (replicated on all ranks)
245+
# Security NOTE: weights_only=False is used here on NVIDIA-generated file, not on untrusted user input
245246
common_modelopt_state = torch.load(
246247
modelopt_checkpoint_name + "/" + COMMON_STATE_FNAME, weights_only=False
247248
)

modelopt/torch/opt/plugins/megatron.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def _modelopt_set_extra_state(self, state: Any):
102102
# Default format: byte tensor with pickled data
103103
#
104104
# TODO: possible deserialization improvement
105-
# https://github.com/NVIDIA/TensorRT-LLM/commits/main/tensorrt_llm/serialization.py
105+
# https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/serialization.py
106106
extra_state = pickle.loads(state.detach().cpu().numpy().tobytes()) # nosec
107107
else:
108108
raise RuntimeError("Unsupported extra_state format.")

0 commit comments

Comments
 (0)