Skip to content

Commit 41e879b

Browse files
committed
Limit B614 to torch.load deserializers
Avoids false positives for torch.*.load helpers such as torch.utils.cpp_extension.load while preserving checks for torch.load and torch.serialization.load. Updated docstrings and example to reflect expected behavior. Resolves: #1343
1 parent 06fbbab commit 41e879b

2 files changed

Lines changed: 15 additions & 16 deletions

File tree

bandit/plugins/pytorch_load.py

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66
B614: Test for unsafe PyTorch load
77
==================================
88
9-
This plugin checks for unsafe use of `torch.load`. Using `torch.load` with
10-
untrusted data can lead to arbitrary code execution. There are two safe
11-
alternatives:
9+
This plugin checks for unsafe use of `torch.load` and `torch.serialization.load`.
10+
Using `torch.load` or `torch.serialization.load` with untrusted data can lead to
11+
arbitrary code execution. There are two safe alternatives:
1212
1313
1. Use `torch.load` with `weights_only=True` where only tensor data is
1414
extracted, and no arbitrary Python objects are deserialized
@@ -24,7 +24,7 @@
2424
2525
>> Issue: Use of unsafe PyTorch load
2626
Severity: Medium Confidence: High
27-
CWE: CWE-94 (https://cwe.mitre.org/data/definitions/94.html)
27+
CWE: CWE-502 (https://cwe.mitre.org/data/definitions/502.html)
2828
Location: examples/pytorch_load_save.py:8
2929
7 loaded_model.load_state_dict(torch.load('model_weights.pth'))
3030
8 another_model.load_state_dict(torch.load('model_weights.pth',
@@ -34,7 +34,7 @@
3434
3535
.. seealso::
3636
37-
- https://cwe.mitre.org/data/definitions/94.html
37+
- https://cwe.mitre.org/data/definitions/502.html
3838
- https://pytorch.org/docs/stable/generated/torch.load.html#torch.load
3939
- https://github.com/huggingface/safetensors
4040
@@ -50,23 +50,17 @@
5050
@test.test_id("B614")
5151
def pytorch_load(context):
5252
"""
53-
This plugin checks for unsafe use of `torch.load`. Using `torch.load`
54-
with untrusted data can lead to arbitrary code execution. The safe
55-
alternative is to use `weights_only=True` or the safetensors library.
53+
This plugin checks for unsafe use of `torch.load` and `torch.serialization.load`.
54+
Using `torch.load` or `torch.serialization.load` with untrusted data can lead to
55+
arbitrary code execution. The safe alternative is to use `weights_only=True` or the
56+
safetensors library.
5657
"""
5758
imported = context.is_module_imported_exact("torch")
5859
qualname = context.call_function_name_qual
5960
if not imported and isinstance(qualname, str):
6061
return
6162

62-
qualname_list = qualname.split(".")
63-
func = qualname_list[-1]
64-
if all(
65-
[
66-
"torch" in qualname_list,
67-
func == "load",
68-
]
69-
):
63+
if qualname in {"torch.load", "torch.serialization.load"}:
7064
# For torch.load, check if weights_only=True is specified
7165
weights_only = context.get_call_arg_value("weights_only")
7266
if weights_only == "True" or weights_only is True:

examples/pytorch_load.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,8 @@
2424
# Example of loading with both map_location and weights_only=True (should NOT trigger B614)
2525
safe_cpu_model = models.resnet18()
2626
safe_cpu_model.load_state_dict(torch.load('model_weights.pth', map_location='cpu', weights_only=True))
27+
28+
# Example of a torch.*.load call that should NOT trigger B614
29+
# Only pickle deserializers should trigger B614
30+
if False: # Static analysis only; does not execute
31+
torch.utils.cpp_extension.load(name="example_ext", sources=[])

0 commit comments

Comments
 (0)