Skip to content

Commit 4137e24

Browse files
sararobcopybara-github
authored andcommitted
fix: add safeguards and warnings for remote code execution during pickle-based model deserialization
PiperOrigin-RevId: 874058205
1 parent a204e74 commit 4137e24

File tree

6 files changed

+215
-10
lines changed

6 files changed

+215
-10
lines changed

google/cloud/aiplatform/metadata/_models.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import os
2020
import pickle
2121
import tempfile
22+
import warnings
2223
from typing import Any, Dict, Optional, Sequence, Union
2324

2425
from google.auth import credentials as auth_credentials
@@ -147,6 +148,13 @@ def _load_sklearn_model(
147148
f"You are using sklearn {sklearn.__version__}."
148149
"Attempting to load model..."
149150
)
151+
152+
warnings.warn(
153+
"Loading a scikit-learn model via pickle is insecure. "
154+
"Ensure the model artifact is from a trusted source.",
155+
RuntimeWarning,
156+
)
157+
150158
with open(model_file, "rb") as f:
151159
sk_model = pickle.load(f)
152160

google/cloud/aiplatform/prediction/predictor.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,15 @@ def __init__(self):
4040
return
4141

4242
@abstractmethod
43-
def load(self, artifacts_uri: str) -> None:
43+
def load(self, artifacts_uri: str, **kwargs) -> None:
4444
"""Loads the model artifact.
4545
4646
Args:
4747
artifacts_uri (str):
4848
Required. The value of the environment variable AIP_STORAGE_URI.
49+
**kwargs:
50+
Optional. Additional keyword arguments for security or
51+
configuration (e.g., allowed_extensions).
4952
"""
5053
pass
5154

google/cloud/aiplatform/prediction/sklearn/predictor.py

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import numpy as np
2020
import os
2121
import pickle
22+
import warnings
2223

2324
from google.cloud.aiplatform.constants import prediction
2425
from google.cloud.aiplatform.utils import prediction_utils
@@ -31,29 +32,65 @@ class SklearnPredictor(Predictor):
3132
def __init__(self):
3233
return
3334

34-
def load(self, artifacts_uri: str) -> None:
35+
def load(self, artifacts_uri: str, **kwargs) -> None:
3536
"""Loads the model artifact.
3637
3738
Args:
3839
artifacts_uri (str):
3940
Required. The value of the environment variable AIP_STORAGE_URI.
41+
**kwargs:
42+
Optional. Additional keyword arguments for security or
43+
configuration. Supported arguments:
44+
allowed_extensions (list[str]):
45+
The allowed file extensions for model artifacts.
46+
If not provided, a UserWarning is issued.
4047
4148
Raises:
4249
ValueError: If there's no required model files provided in the artifacts
4350
uri.
4451
"""
52+
53+
allowed_extensions = kwargs.get("allowed_extensions", None)
54+
55+
if allowed_extensions is None:
56+
warnings.warn(
57+
"No 'allowed_extensions' provided. Loading model artifacts from "
58+
"untrusted sources may lead to remote code execution.",
59+
UserWarning,
60+
)
61+
4562
prediction_utils.download_model_artifacts(artifacts_uri)
46-
if os.path.exists(prediction.MODEL_FILENAME_JOBLIB):
63+
if os.path.exists(
64+
prediction.MODEL_FILENAME_JOBLIB
65+
) and prediction_utils.is_allowed(
66+
filename=prediction.MODEL_FILENAME_JOBLIB,
67+
allowed_extensions=allowed_extensions,
68+
):
69+
warnings.warn(
70+
f"Loading {prediction.MODEL_FILENAME_JOBLIB} using joblib pickle, which is unsafe. "
71+
"Only load files from trusted sources.",
72+
RuntimeWarning,
73+
)
4774
self._model = joblib.load(prediction.MODEL_FILENAME_JOBLIB)
48-
elif os.path.exists(prediction.MODEL_FILENAME_PKL):
75+
elif os.path.exists(
76+
prediction.MODEL_FILENAME_PKL
77+
) and prediction_utils.is_allowed(
78+
filename=prediction.MODEL_FILENAME_PKL,
79+
allowed_extensions=allowed_extensions,
80+
):
81+
warnings.warn(
82+
f"Loading {prediction.MODEL_FILENAME_PKL} using pickle, which is unsafe. "
83+
"Only load files from trusted sources.",
84+
RuntimeWarning,
85+
)
4986
self._model = pickle.load(open(prediction.MODEL_FILENAME_PKL, "rb"))
5087
else:
5188
valid_filenames = [
5289
prediction.MODEL_FILENAME_JOBLIB,
5390
prediction.MODEL_FILENAME_PKL,
5491
]
5592
raise ValueError(
56-
f"One of the following model files must be provided: {valid_filenames}."
93+
f"One of the following model files must be provided and allowed: {valid_filenames}."
5794
)
5895

5996
def preprocess(self, prediction_input: dict) -> np.ndarray:

google/cloud/aiplatform/prediction/xgboost/predictor.py

Lines changed: 47 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import logging
2020
import os
2121
import pickle
22+
import warnings
2223

2324
import numpy as np
2425
import xgboost as xgb
@@ -34,21 +35,52 @@ class XgboostPredictor(Predictor):
3435
def __init__(self):
3536
return
3637

37-
def load(self, artifacts_uri: str) -> None:
38+
def load(self, artifacts_uri: str, **kwargs) -> None:
3839
"""Loads the model artifact.
3940
4041
Args:
4142
artifacts_uri (str):
4243
Required. The value of the environment variable AIP_STORAGE_URI.
44+
**kwargs:
45+
Optional. Additional keyword arguments for security or
46+
configuration. Supported arguments:
47+
allowed_extensions (list[str]):
48+
The allowed file extensions for model artifacts.
49+
If not provided, a UserWarning is issued.
4350
4451
Raises:
4552
ValueError: If there's no required model files provided in the artifacts
4653
uri.
4754
"""
55+
allowed_extensions = kwargs.get("allowed_extensions", None)
56+
57+
if allowed_extensions is None:
58+
warnings.warn(
59+
"No 'allowed_extensions' provided. Loading model artifacts from "
60+
"untrusted sources may lead to remote code execution.",
61+
UserWarning,
62+
)
63+
4864
prediction_utils.download_model_artifacts(artifacts_uri)
49-
if os.path.exists(prediction.MODEL_FILENAME_BST):
65+
66+
if os.path.exists(
67+
prediction.MODEL_FILENAME_BST
68+
) and prediction_utils.is_allowed(
69+
filename=prediction.MODEL_FILENAME_BST,
70+
allowed_extensions=allowed_extensions,
71+
):
5072
booster = xgb.Booster(model_file=prediction.MODEL_FILENAME_BST)
51-
elif os.path.exists(prediction.MODEL_FILENAME_JOBLIB):
73+
elif os.path.exists(
74+
prediction.MODEL_FILENAME_JOBLIB
75+
) and prediction_utils.is_allowed(
76+
filename=prediction.MODEL_FILENAME_JOBLIB,
77+
allowed_extensions=allowed_extensions,
78+
):
79+
warnings.warn(
80+
f"Loading {prediction.MODEL_FILENAME_JOBLIB} using joblib pickle, which is unsafe. "
81+
"Only load files from trusted sources.",
82+
RuntimeWarning,
83+
)
5284
try:
5385
booster = joblib.load(prediction.MODEL_FILENAME_JOBLIB)
5486
except KeyError:
@@ -58,7 +90,17 @@ def load(self, artifacts_uri: str) -> None:
5890
)
5991
booster = xgb.Booster()
6092
booster.load_model(prediction.MODEL_FILENAME_JOBLIB)
61-
elif os.path.exists(prediction.MODEL_FILENAME_PKL):
93+
elif os.path.exists(
94+
prediction.MODEL_FILENAME_PKL
95+
) and prediction_utils.is_allowed(
96+
filename=prediction.MODEL_FILENAME_PKL,
97+
allowed_extensions=allowed_extensions,
98+
):
99+
warnings.warn(
100+
f"Loading {prediction.MODEL_FILENAME_PKL} using pickle, which is unsafe. "
101+
"Only load files from trusted sources.",
102+
RuntimeWarning,
103+
)
62104
booster = pickle.load(open(prediction.MODEL_FILENAME_PKL, "rb"))
63105
else:
64106
valid_filenames = [
@@ -67,7 +109,7 @@ def load(self, artifacts_uri: str) -> None:
67109
prediction.MODEL_FILENAME_PKL,
68110
]
69111
raise ValueError(
70-
f"One of the following model files must be provided: {valid_filenames}."
112+
f"One of the following model files must be provided and allowed: {valid_filenames}."
71113
)
72114
self._booster = booster
73115

google/cloud/aiplatform/utils/prediction_utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,3 +174,9 @@ def add_flex_start_to_dedicated_resources(
174174
dedicated_resources.flex_start = gca_machine_resources_compat.FlexStart(
175175
max_runtime_duration=duration_pb2.Duration(seconds=max_runtime_duration)
176176
)
177+
178+
179+
def is_allowed(filename: str, allowed_extensions: Optional[list[str]]) -> bool:
180+
if allowed_extensions is None:
181+
return True
182+
return any(filename.endswith(ext) for ext in allowed_extensions)
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
# -*- coding: utf-8 -*-
2+
3+
# Copyright 2026 Google LLC
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
import pytest
19+
from unittest import mock
20+
from google.cloud.aiplatform.prediction.xgboost.predictor import (
21+
XgboostPredictor,
22+
)
23+
from google.cloud.aiplatform.prediction.sklearn.predictor import (
24+
SklearnPredictor,
25+
)
26+
27+
from google.cloud.aiplatform.utils import prediction_utils
28+
from google.cloud.aiplatform.prediction.xgboost import (
29+
predictor as xgboost_predictor,
30+
)
31+
32+
33+
class TestPredictorSecurity:
34+
@pytest.mark.parametrize("predictor_class", [XgboostPredictor, SklearnPredictor])
35+
def test_load_warns_no_allowed_extensions(self, predictor_class):
36+
"""Verifies UserWarning is issued when allowed_extensions is missing."""
37+
predictor = predictor_class()
38+
with mock.patch.object(
39+
prediction_utils, "download_model_artifacts"
40+
), mock.patch("os.path.exists", return_value=True), mock.patch(
41+
"joblib.load"
42+
), mock.patch(
43+
"pickle.load"
44+
), mock.patch.object(
45+
xgboost_predictor.xgb, "Booster"
46+
), mock.patch(
47+
"builtins.open", mock.mock_open()
48+
):
49+
with pytest.warns(UserWarning, match="No 'allowed_extensions' provided"):
50+
predictor.load("gs://test-bucket")
51+
52+
def test_xgboost_load_warns_on_joblib(self):
53+
"""Verifies RuntimeWarning is issued when loading a .joblib file."""
54+
predictor = XgboostPredictor()
55+
with mock.patch.object(
56+
prediction_utils, "download_model_artifacts"
57+
), mock.patch(
58+
"os.path.exists", side_effect=lambda p: p.endswith(".joblib")
59+
), mock.patch(
60+
"joblib.load"
61+
), mock.patch.object(
62+
xgboost_predictor.xgb, "Booster"
63+
):
64+
with pytest.warns(
65+
RuntimeWarning, match="using joblib pickle, which is unsafe"
66+
):
67+
predictor.load("gs://test-bucket", allowed_extensions=[".joblib"])
68+
69+
def test_xgboost_load_raises_not_allowed(self):
70+
"""Verifies ValueError is raised if the file exists but is not allowed."""
71+
predictor = XgboostPredictor()
72+
with mock.patch.object(
73+
prediction_utils, "download_model_artifacts"
74+
), mock.patch.object(xgboost_predictor.xgb, "Booster"), mock.patch(
75+
"os.path.exists", side_effect=lambda p: p.endswith(".pkl")
76+
):
77+
with pytest.raises(ValueError, match="must be provided and allowed"):
78+
predictor.load("gs://test-bucket", allowed_extensions=[".bst"])
79+
80+
def test_sklearn_load_warns_on_pickle(self):
81+
"""Verifies RuntimeWarning is issued when loading a .pkl file."""
82+
predictor = SklearnPredictor()
83+
with mock.patch.object(
84+
prediction_utils, "download_model_artifacts"
85+
), mock.patch(
86+
"os.path.exists", side_effect=lambda p: p.endswith(".pkl")
87+
), mock.patch(
88+
"builtins.open", mock.mock_open()
89+
), mock.patch(
90+
"pickle.load"
91+
):
92+
93+
with pytest.warns(RuntimeWarning, match="using pickle, which is unsafe"):
94+
predictor.load("gs://test-bucket", allowed_extensions=[".pkl"])
95+
96+
def test_sklearn_load_warns_on_joblib(self):
97+
"""Verifies RuntimeWarning is issued when loading a .joblib file in Scikit-learn."""
98+
predictor = SklearnPredictor()
99+
with mock.patch.object(
100+
prediction_utils, "download_model_artifacts"
101+
), mock.patch(
102+
"os.path.exists", side_effect=lambda p: p.endswith(".joblib")
103+
), mock.patch(
104+
"joblib.load"
105+
):
106+
with pytest.warns(
107+
RuntimeWarning, match=r"using joblib pickle, which is unsafe"
108+
):
109+
predictor.load("gs://test-bucket", allowed_extensions=[".joblib"])

0 commit comments

Comments
 (0)