Skip to content

Commit b3bab4b

Browse files
author
Han Wang
committed
fix(pt_expt): use weights_only=True when reading .pt checkpoints
`Backend.detect_backend_by_model` and `pt_expt.DeepEval._load_pt` deserialised `.pt` files with `weights_only=False`, which allows arbitrary code execution from a malicious checkpoint. The training resume path (training.py:712) already uses `weights_only=True`; align the two new sites with that convention. Reported by chatgpt-codex-connector on PR #5423.
1 parent 7158830 commit b3bab4b

2 files changed

Lines changed: 6 additions & 2 deletions

File tree

deepmd/backend/backend.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,9 @@ def detect_backend_by_model(filename: str) -> type["Backend"]:
109109
try:
110110
import torch
111111

112-
sd = torch.load(filename, map_location="cpu", weights_only=False)
112+
# Use weights_only=True to avoid executing arbitrary pickle
113+
# from an untrusted .pt — sniffing only needs the dict keys.
114+
sd = torch.load(filename, map_location="cpu", weights_only=True)
113115
if isinstance(sd, dict) and "model" in sd:
114116
sd = sd["model"]
115117
keys = list(sd.keys()) if hasattr(sd, "keys") else []

deepmd/pt_expt/infer/deep_eval.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,9 @@ def _load_pt(self, model_file: str, head: str | None = None) -> None:
227227
get_model,
228228
)
229229

230-
state_dict = torch.load(model_file, map_location=DEVICE, weights_only=False)
230+
# Match the training resume path (training.py:712) — weights_only=True
231+
# avoids unpickling arbitrary code from untrusted checkpoints.
232+
state_dict = torch.load(model_file, map_location=DEVICE, weights_only=True)
231233
if "model" in state_dict:
232234
state_dict = state_dict["model"]
233235
model_params = deepcopy(state_dict["_extra_state"]["model_params"])

0 commit comments

Comments
 (0)