Skip to content

Commit 1d055e3

Browse files
wyf7107 authored and copybara-github committed
fix(migration): restrict unpickling of v0 actions blobs
Port of GitHub PR: 9db48ce

Restrict unpickling of v0 actions blobs during session migration to prevent unsafe deserialization. Add an --allow-unsafe-unpickling option to allow unsafe unpickling for trusted databases.

Co-authored-by: Yifan Wang <wanyif@google.com>
PiperOrigin-RevId: 927430066
1 parent 8befdb8 commit 1d055e3

5 files changed

Lines changed: 533 additions & 15 deletions

File tree

src/google/adk/cli/cli_tools_click.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2205,15 +2205,33 @@ def migrate():
22052205
default="INFO",
22062206
help="Optional. Set the logging level",
22072207
)
2208+
@click.option( # type: ignore[untyped-decorator]
2209+
"--allow-unsafe-unpickling",
2210+
"--allow_unsafe_unpickling",
2211+
is_flag=True,
2212+
default=False,
2213+
help=(
2214+
"Optional. Allow unsafe pickle loading for trusted legacy session"
2215+
" databases."
2216+
),
2217+
)
22082218
def cli_migrate_session(
2209-
*, source_db_url: str, dest_db_url: str, log_level: str
2219+
*,
2220+
source_db_url: str,
2221+
dest_db_url: str,
2222+
log_level: str,
2223+
allow_unsafe_unpickling: bool,
22102224
):
22112225
"""Migrates a session database to the latest schema version."""
22122226
logs.setup_adk_logger(getattr(logging, log_level.upper()))
22132227
try:
22142228
from ..sessions.migration import migration_runner
22152229

2216-
migration_runner.upgrade(source_db_url, dest_db_url)
2230+
migration_runner.upgrade(
2231+
source_db_url,
2232+
dest_db_url,
2233+
allow_unsafe_unpickling=allow_unsafe_unpickling,
2234+
)
22172235
click.secho("Migration check and upgrade process finished.", fg="green")
22182236
except Exception as e:
22192237
click.secho(f"Migration failed: {e}", fg="red", err=True)

src/google/adk/sessions/migration/migrate_from_sqlalchemy_pickle.py

Lines changed: 140 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import argparse
1919
from datetime import datetime
2020
from datetime import timezone
21+
import io
2122
import json
2223
import logging
2324
import pickle
@@ -37,6 +38,93 @@
3738

3839
logger = logging.getLogger("google_adk." + __name__)
3940

41+
# Allowlist of (module, name) pairs that may be resolved while unpickling
# legacy v0 `actions` blobs. Names are grouped per module below and expanded
# into a flat set of pairs so membership checks stay a single tuple lookup.
_ALLOWED_PICKLE_GLOBALS = {
    (module_name, global_name)
    for module_name, global_names in (
        # Basic types/containers
        (
            "builtins",
            (
                "dict",
                "list",
                "set",
                "tuple",
                "str",
                "bytes",
                "bytearray",
                "int",
                "float",
                "bool",
            ),
        ),
        ("datetime", ("datetime", "timedelta", "timezone")),
        # Expected pickled payload for v0 session schema events.
        (
            "fastapi.openapi.models",
            (
                "APIKey",
                "APIKeyIn",
                "HTTPBase",
                "HTTPBearer",
                "OAuth2",
                "OAuthFlow",
                "OAuthFlowAuthorizationCode",
                "OAuthFlowClientCredentials",
                "OAuthFlowImplicit",
                "OAuthFlowPassword",
                "OAuthFlows",
                "OpenIdConnect",
                "SecurityBase",
                "SecurityScheme",
                "SecuritySchemeType",
            ),
        ),
        (
            "google.adk.auth.auth_credential",
            (
                "AuthCredential",
                "AuthCredentialTypes",
                "HttpAuth",
                "HttpCredentials",
                "OAuth2Auth",
                "ServiceAccountCredential",
            ),
        ),
        (
            "google.adk.auth.auth_schemes",
            (
                "CustomAuthScheme",
                "ExtendedOAuth2",
                "OAuthGrantType",
                "OpenIdConnectWithConfig",
            ),
        ),
        ("google.adk.auth.auth_tool", ("AuthConfig",)),
        (
            "google.adk.events.event_actions",
            ("EventActions", "EventCompaction"),
        ),
        ("google.adk.events.ui_widget", ("UiWidget",)),
        ("google.adk.tools.tool_confirmation", ("ToolConfirmation",)),
        (
            "google.genai.types",
            (
                "Blob",
                "CodeExecutionResult",
                "Content",
                "ExecutableCode",
                "FileData",
                "FunctionCall",
                "FunctionResponse",
                "FunctionResponseBlob",
                "FunctionResponseFileData",
                "FunctionResponsePart",
                "Part",
                "PartMediaResolution",
                "VideoMetadata",
            ),
        ),
    )
    for global_name in global_names
}
101+
102+
103+
class _RestrictedUnpickler(pickle.Unpickler):
  """Unpickler that only resolves allowlisted globals.

  Legacy v0 session rows stored `EventActions` as a pickled blob. The bytes
  read from the source DB during migration are treated as untrusted, so global
  lookups are limited to the pairs listed in `_ALLOWED_PICKLE_GLOBALS`.
  """

  def find_class(self, module: str, name: str) -> Any:  # noqa: ANN001
    """Resolve `module.name` if it is allowlisted, otherwise refuse.

    Raises:
      pickle.UnpicklingError: If `(module, name)` is not in the allowlist.
    """
    requested = (module, name)
    if requested not in _ALLOWED_PICKLE_GLOBALS:
      raise pickle.UnpicklingError(
          f"Blocked global during migration unpickle: {module}.{name}"
      )
    return super().find_class(module, name)
118+
119+
120+
def _restricted_pickle_loads(
    data: bytes, *, allow_unsafe_unpickling: bool = False
) -> Any:
  """Deserialize a pickle payload, restricting globals unless opted out.

  Args:
    data: Raw pickle bytes to load.
    allow_unsafe_unpickling: When true, use plain `pickle.loads` instead of the
      restricted unpickler.

  Returns:
    The deserialized object.
  """
  if not allow_unsafe_unpickling:
    return _RestrictedUnpickler(io.BytesIO(data)).load()
  # Explicit opt-out: no global filtering is applied on this path.
  return pickle.loads(data)
127+
40128

41129
def _to_datetime_obj(val: Any) -> datetime | Any:
42130
"""Converts string to datetime if needed."""
@@ -51,15 +139,19 @@ def _to_datetime_obj(val: Any) -> datetime | Any:
51139
return val
52140

53141

54-
def _row_to_event(row: dict) -> Event:
142+
def _row_to_event(
143+
row: dict[str, Any], *, allow_unsafe_unpickling: bool = False
144+
) -> Event:
55145
"""Converts event row (dict) to event object, handling missing columns and deserializing."""
56146

57147
actions_val = row.get("actions")
58148
actions = None
59149
if actions_val is not None:
60150
try:
61151
if isinstance(actions_val, bytes):
62-
actions = pickle.loads(actions_val)
152+
actions = _restricted_pickle_loads(
153+
actions_val, allow_unsafe_unpickling=allow_unsafe_unpickling
154+
)
63155
else: # for spanner - it might return object directly
64156
actions = actions_val
65157
except Exception as e:
@@ -75,17 +167,25 @@ def _row_to_event(row: dict) -> Event:
75167
else:
76168
actions = EventActions()
77169

78-
def _safe_json_load(val):
79-
data = None
170+
def _safe_json_load(val: Any) -> dict[str, Any] | None:
80171
if isinstance(val, str):
81172
try:
82173
data = json.loads(val)
83174
except json.JSONDecodeError:
84175
logger.warning(f"Failed to decode JSON for event {row.get('id')}")
85176
return None
86177
elif isinstance(val, dict):
87-
data = val # for postgres JSONB
88-
return data
178+
return val # for postgres JSONB
179+
else:
180+
return None
181+
182+
if isinstance(data, dict):
183+
return data
184+
logger.warning(
185+
f"Expected JSON object for event {row.get('id')}, got"
186+
f" {type(data).__name__}."
187+
)
188+
return None
89189

90190
content_dict = _safe_json_load(row.get("content"))
91191
grounding_metadata_dict = _safe_json_load(row.get("grounding_metadata"))
@@ -147,23 +247,31 @@ def _safe_json_load(val):
147247
)
148248

149249

150-
def _get_state_dict(state_val: Any) -> dict:
250+
def _get_state_dict(state_val: Any) -> dict[str, Any]:
  """Return a state dict from a dict or a JSON-object string.

  A dict is returned as-is. A string is parsed as JSON and returned only if it
  decodes to a dict. Every other case yields an empty dict.
  """
  if isinstance(state_val, dict):
    return state_val
  if not isinstance(state_val, str):
    return {}

  try:
    parsed = json.loads(state_val)
  except json.JSONDecodeError:
    logger.warning(
        "Failed to parse state JSON string, defaulting to empty dict."
    )
    return {}

  if not isinstance(parsed, dict):
    logger.warning("State JSON was not an object, defaulting to empty dict.")
    return {}
  return parsed
163267

164268

165269
# --- Migration Logic ---
166-
def migrate(source_db_url: str, dest_db_url: str):
270+
def migrate(
271+
source_db_url: str,
272+
dest_db_url: str,
273+
allow_unsafe_unpickling: bool = False,
274+
) -> None:
167275
"""Migrates data from old pickle schema to new JSON schema."""
168276
# Convert async driver URLs to sync URLs for SQLAlchemy's synchronous engine.
169277
# This allows users to provide URLs like 'postgresql+asyncpg://...' and have
@@ -172,6 +280,11 @@ def migrate(source_db_url: str, dest_db_url: str):
172280
dest_sync_url = _schema_check_utils.to_sync_url(dest_db_url)
173281

174282
logger.info(f"Connecting to source database: {source_db_url}")
283+
if allow_unsafe_unpickling:
284+
logger.warning(
285+
"Unsafe pickle migration mode is enabled. Only use this with a trusted"
286+
" source database."
287+
)
175288
try:
176289
source_engine = create_engine(source_sync_url)
177290
SourceSession = sessionmaker(bind=source_engine)
@@ -265,7 +378,10 @@ def migrate(source_db_url: str, dest_db_url: str):
265378
text("SELECT * FROM events")
266379
).mappings():
267380
try:
268-
event_obj = _row_to_event(dict(row))
381+
event_obj = _row_to_event(
382+
dict(row),
383+
allow_unsafe_unpickling=allow_unsafe_unpickling,
384+
)
269385
new_event = v1.StorageEvent(
270386
id=event_obj.id,
271387
app_name=row["app_name"],
@@ -309,9 +425,22 @@ def migrate(source_db_url: str, dest_db_url: str):
309425
required=True,
310426
help="SQLAlchemy URL of destination database",
311427
)
428+
parser.add_argument(
429+
"--allow_unsafe_unpickling",
430+
"--allow-unsafe-unpickling",
431+
action="store_true",
432+
help=(
433+
"Allow legacy pickle payloads to use Python's unsafe pickle loader."
434+
" Only use this with a trusted source database."
435+
),
436+
)
312437
args = parser.parse_args()
313438
try:
314-
migrate(args.source_db_url, args.dest_db_url)
439+
migrate(
440+
args.source_db_url,
441+
args.dest_db_url,
442+
allow_unsafe_unpickling=args.allow_unsafe_unpickling,
443+
)
315444
except Exception as e:
316445
logger.error(f"Migration failed: {e}")
317446
sys.exit(1)

src/google/adk/sessions/migration/migration_runner.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,11 @@
4242
LATEST_VERSION = _schema_check_utils.LATEST_SCHEMA_VERSION
4343

4444

45-
def upgrade(source_db_url: str, dest_db_url: str):
45+
def upgrade(
46+
source_db_url: str,
47+
dest_db_url: str,
48+
allow_unsafe_unpickling: bool = False,
49+
) -> None:
4650
"""Migrates a database from its current version to the latest version.
4751
4852
If the source database schema is older than the latest version, this
@@ -61,6 +65,9 @@ def upgrade(source_db_url: str, dest_db_url: str):
6165
source_db_url: The SQLAlchemy URL of the database to migrate from.
6266
dest_db_url: The SQLAlchemy URL of the database to migrate to. This must be
6367
different from source_db_url.
68+
allow_unsafe_unpickling: If true, use Python's unsafe pickle loader for the
69+
legacy pickle migration step. Only use this with a trusted source
70+
database.
6471
6572
Raises:
6673
RuntimeError: If source_db_url and dest_db_url are the same, or if no
@@ -113,7 +120,14 @@ def upgrade(source_db_url: str, dest_db_url: str):
113120
logger.info(
114121
f"Migrating from {in_url} to {out_url} (schema v{end_version})..."
115122
)
116-
migrate_func(in_url, out_url)
123+
if migrate_func is migrate_from_sqlalchemy_pickle.migrate:
124+
migrate_func(
125+
in_url,
126+
out_url,
127+
allow_unsafe_unpickling=allow_unsafe_unpickling,
128+
)
129+
else:
130+
migrate_func(in_url, out_url)
117131
logger.info("Finished migration step to schema %s.", end_version)
118132
# The output of this step becomes the input for the next step.
119133
in_url = out_url

tests/unittests/cli/utils/test_cli_tools_click.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1193,6 +1193,53 @@ def test_cli_web_passes_service_uris(
11931193
assert called_kwargs.get("memory_service_uri") == "rag://mycorpus"
11941194

11951195

1196+
@pytest.mark.parametrize(
    "flag",
    ["--allow-unsafe-unpickling", "--allow_unsafe_unpickling"],
)
def test_cli_migrate_session_allows_unsafe_unpickling_flag(
    monkeypatch: pytest.MonkeyPatch, flag: str
) -> None:
  """Either spelling of the flag reaches `upgrade` as `True`."""
  recorded_calls: list[dict[str, Any]] = []

  def fake_upgrade(
      source_db_url: str,
      dest_db_url: str,
      *,
      allow_unsafe_unpickling: bool = False,
  ) -> None:
    # Capture the arguments instead of running a real migration.
    captured = {
        "source_db_url": source_db_url,
        "dest_db_url": dest_db_url,
        "allow_unsafe_unpickling": allow_unsafe_unpickling,
    }
    recorded_calls.append(captured)

  monkeypatch.setattr(
      "google.adk.sessions.migration.migration_runner.upgrade",
      fake_upgrade,
  )

  cli_args = [
      "migrate",
      "session",
      "--source_db_url",
      "sqlite:///source.db",
      "--dest_db_url",
      "sqlite:///dest.db",
      flag,
  ]
  result = CliRunner().invoke(cli_tools_click.main, cli_args)

  assert result.exit_code == 0, (result.output, repr(result.exception))
  expected_call = {
      "source_db_url": "sqlite:///source.db",
      "dest_db_url": "sqlite:///dest.db",
      "allow_unsafe_unpickling": True,
  }
  assert recorded_calls == [expected_call]
1241+
1242+
11961243
def test_cli_eval_with_eval_set_file_path(
11971244
mock_load_eval_set_from_file,
11981245
mock_get_root_agent,

0 commit comments

Comments
 (0)