Skip to content

Commit cdb5082

Browse files
SunsetWolfhugo2046
authored and committed
fix(security): restrict pickle deserialization to safe classes (microsoft#2076)
1 parent fccbf03 commit cdb5082

3 files changed

Lines changed: 180 additions & 8 deletions

File tree

qlib/rl/data/native.py

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -2,17 +2,18 @@
22
# Licensed under the MIT License.
33
from __future__ import annotations
44

5+
import os
56
from pathlib import Path
6-
from typing import cast, List
7+
from typing import List, cast
78

89
import cachetools
910
import pandas as pd
10-
import pickle
11-
import os
1211

1312
from qlib.backtest import Exchange, Order
1413
from qlib.backtest.decision import TradeRange, TradeRangeByTime
1514
from qlib.constant import EPS_T
15+
from qlib.utils.pickle_utils import restricted_pickle_load
16+
1617
from .base import BaseIntradayBacktestData, BaseIntradayProcessedData, ProcessedDataProvider
1718

1819

@@ -162,7 +163,7 @@ def _drop_stock_id(df: pd.DataFrame) -> pd.DataFrame:
162163
path = os.path.join(data_dir, "backtest" if backtest else "feature", f"{stock_id}.pkl")
163164
start_time, end_time = date.replace(hour=0, minute=0, second=0), date.replace(hour=23, minute=59, second=59)
164165
with open(path, "rb") as fstream:
165-
dataset = pickle.load(fstream)
166+
dataset = restricted_pickle_load(fstream)
166167
data = dataset.handler.fetch(pd.IndexSlice[stock_id, start_time:end_time], level=None)
167168

168169
if index_only:

qlib/utils/pickle_utils.py

Lines changed: 171 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,171 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
"""
4+
Secure pickle utilities to prevent arbitrary code execution through deserialization.
5+
6+
This module provides a secure alternative to pickle.load() and pickle.loads()
7+
that restricts deserialization to a whitelist of safe classes.
8+
"""
9+
10+
import io
11+
import pickle
12+
from typing import Any, BinaryIO, Set, Tuple
13+
14+
# Allowlist of (module, name) pairs that RestrictedUnpickler.find_class will
# resolve. Anything not listed here (and not under a trusted package, see
# TRUSTED_MODULE_PREFIXES) is rejected with pickle.UnpicklingError.
SAFE_PICKLE_CLASSES: Set[Tuple[str, str]] = {
    # python builtins
    ("builtins", "slice"),
    ("builtins", "range"),
    ("builtins", "dict"),
    ("builtins", "list"),
    ("builtins", "tuple"),
    ("builtins", "set"),
    ("builtins", "frozenset"),
    ("builtins", "bytearray"),
    ("builtins", "bytes"),
    ("builtins", "str"),
    ("builtins", "int"),
    ("builtins", "float"),
    ("builtins", "bool"),
    ("builtins", "complex"),
    ("builtins", "type"),
    ("builtins", "property"),
    # common utility classes
    ("datetime", "datetime"),
    ("datetime", "date"),
    ("datetime", "time"),
    ("datetime", "timedelta"),
    ("datetime", "timezone"),
    ("decimal", "Decimal"),
    ("collections", "OrderedDict"),
    ("collections", "defaultdict"),
    ("collections", "Counter"),
    ("collections", "namedtuple"),
    ("enum", "Enum"),
    ("pathlib", "Path"),
    ("pathlib", "PosixPath"),
    ("pathlib", "WindowsPath"),
    # qlib internal classes
    ("qlib.data.dataset.handler", "DataHandler"),
    ("qlib.data.dataset.handler", "DataHandlerLP"),
    ("qlib.data.dataset.loader", "StaticDataLoader"),
}


# Top-level packages whose globals are resolved without consulting
# SAFE_PICKLE_CLASSES.
# NOTE(review): every global under these packages is resolvable, including
# any callable they expose -- confirm this breadth is intended.
TRUSTED_MODULE_PREFIXES = (
    "pandas",
    "numpy",
)


def _is_trusted_module(module: str) -> bool:
    """Return True if ``module`` is a trusted package or one of its submodules."""
    for prefix in TRUSTED_MODULE_PREFIXES:
        root = prefix.rstrip(".")
        # Match on a package boundary: "pandas" and "pandas.core.frame" are
        # trusted, but an unrelated module such as "pandas_evil" is not.
        if module == root or module.startswith(root + "."):
            return True
    return False


class RestrictedUnpickler(pickle.Unpickler):
    """Unpickler that only resolves allowlisted globals.

    Every global referenced by a pickle stream is looked up through
    :meth:`find_class`. Restricting that lookup to trusted packages and an
    explicit allowlist blocks pickles that reference arbitrary callables.

    Example:
        >>> with open("data.pkl", "rb") as f:
        ...     data = RestrictedUnpickler(f).load()
    """

    def find_class(self, module: str, name: str) -> Any:
        """Resolve ``module.name`` only if it is allowed.

        Args:
            module: Module name recorded in the pickle stream.
            name: Name of the global inside ``module``.

        Returns:
            The resolved object, when ``module`` belongs to a trusted package
            or ``(module, name)`` is in ``SAFE_PICKLE_CLASSES``.

        Raises:
            pickle.UnpicklingError: If the global is not allowed.
        """
        # 1. trusted third-party packages (matched on package boundaries)
        if _is_trusted_module(module):
            return super().find_class(module, name)

        # 2. explicit whitelist (builtins, stdlib helpers, qlib internal)
        if (module, name) in SAFE_PICKLE_CLASSES:
            return super().find_class(module, name)

        raise pickle.UnpicklingError(
            f"Forbidden class: {module}.{name}. "
            "Only whitelisted classes are allowed for security reasons. "
            "This is to prevent arbitrary code execution through pickle deserialization."
        )
97+
98+
99+
def restricted_pickle_load(file: BinaryIO) -> Any:
    """Deserialize one pickled object from ``file`` using RestrictedUnpickler.

    Intended as a replacement for ``pickle.load()``: globals referenced by the
    stream are resolved through ``RestrictedUnpickler.find_class``, so only
    allowed classes can be reconstructed.

    Args:
        file: A binary file object positioned at the start of a pickle.

    Returns:
        The object reconstructed from the stream.

    Raises:
        pickle.UnpicklingError: If the stream references a forbidden class.

    Example:
        >>> with open("data.pkl", "rb") as f:
        ...     data = restricted_pickle_load(f)
    """
    unpickler = RestrictedUnpickler(file)
    return unpickler.load()
119+
120+
121+
def restricted_pickle_loads(data: bytes) -> Any:
    """Deserialize one pickled object from ``data`` using RestrictedUnpickler.

    Intended as a replacement for ``pickle.loads()``: the bytes are wrapped in
    an in-memory stream and decoded by ``RestrictedUnpickler``, so only
    allowed classes can be reconstructed.

    Args:
        data: The pickled payload.

    Returns:
        The object reconstructed from ``data``.

    Raises:
        pickle.UnpicklingError: If the payload references a forbidden class.

    Example:
        >>> payload = pickle.dumps({"a": 1})
        >>> restricted_pickle_loads(payload)
        {'a': 1}
    """
    return RestrictedUnpickler(io.BytesIO(data)).load()
142+
143+
144+
def add_safe_class(module: str, name: str) -> None:
    """Register ``(module, name)`` in the module-level unpickling allowlist.

    The entry is added to ``SAFE_PICKLE_CLASSES`` in place, so it affects
    every later restricted load in this process.

    Args:
        module: Module name of the class (e.g., 'my_package.my_module').
        name: Class name (e.g., 'MyClass').

    Warning:
        Only register classes that you fully control and trust; each entry
        widens what a pickle stream is able to reconstruct.

    Example:
        >>> add_safe_class('my_package.models', 'CustomModel')
    """
    entry: Tuple[str, str] = (module, name)
    SAFE_PICKLE_CLASSES.add(entry)
163+
164+
165+
def get_safe_classes() -> Set[Tuple[str, str]]:
    """Return a snapshot of the current unpickling allowlist.

    Returns:
        A new set of (module, name) tuples; mutating it does not change
        ``SAFE_PICKLE_CLASSES``.
    """
    snapshot = set(SAFE_PICKLE_CLASSES)
    return snapshot

tests/data_mid_layer_tests/test_handler.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,10 +1,10 @@
11
import os
2-
import pickle
3-
import shutil
42
import unittest
5-
from qlib.tests import TestAutoData
3+
64
from qlib.data import D
75
from qlib.data.dataset.handler import DataHandlerLP
6+
from qlib.tests import TestAutoData
7+
from qlib.utils.pickle_utils import restricted_pickle_load
88

99

1010
class HandlerTests(TestAutoData):
@@ -23,7 +23,7 @@ def test_handler_df(self):
2323
dh.to_pickle(fname, dump_all=True)
2424

2525
with open(fname, "rb") as f:
26-
dh_d = pickle.load(f)
26+
dh_d = restricted_pickle_load(f)
2727

2828
self.assertTrue(dh_d._data.equals(df))
2929
self.assertTrue(dh_d._infer is dh_d._data)

0 commit comments

Comments
 (0)