Skip to content

Commit 9db48ce

Browse files
White-Mouse authored and DeanChensj committed
fix(migration): restrict unpickling of v0 actions blobs
## What

The v0 session schema stored event actions as pickled blobs. The migration helper reads raw bytes via `SELECT * FROM events` and previously used `pickle.loads(...)` directly.

This PR replaces the default load path with a restricted unpickler allowlist for builtin containers/primitives, standard ADK `EventActions` payloads, nested ADK core action types (`AuthConfig`, `ToolConfirmation`, `EventCompaction`), and the `google.genai.types.Content` / `Part` dependency classes that normal compaction payloads require.

It also adds an explicit trust toggle for legacy databases that contain custom Python objects in `state_delta` or other `Any` fields:

- Python API: `migrate(..., allow_unsafe_unpickling=True)`
- Migration runner: `upgrade(..., allow_unsafe_unpickling=True)`
- CLI: `adk migrate session --allow_unsafe_unpickling ...`
- Direct script: `--allow_unsafe_unpickling` / `--allow-unsafe-unpickling`

## Why

`pickle` is not safe for untrusted inputs. Migration tooling often runs against restored/backed-up DB files or shared storage; failing closed by default reduces the blast radius if the source DB contents are compromised. The opt-in flag keeps compatibility for users who trust their source database and need the original unsafe pickle behavior for custom legacy objects.

## Associated Issue / Background

No existing GitHub issue is linked. This was found while reviewing the v0-to-v1 migration path for unsafe deserialization risks in legacy session data.

## Compatibility / fail-closed boundary

Normal v0 `EventActions` payloads made from primitive/container fields continue to migrate. The allowlist now also covers common nested ADK action models requested during review, including requested auth configs, requested tool confirmations, and event compaction content.

Payloads that require globals outside the explicit allowlist still fail closed by default: the migration logs a warning and falls back to empty `EventActions()` for that event. Users can opt into the previous unsafe pickle behavior only when they trust the source database.

## Verification

- `uv run pytest tests/unittests/sessions/migration/test_migration.py`
  - `22 passed, 4 warnings`
- `uv run mypy src/google/adk/sessions/migration/migrate_from_sqlalchemy_pickle.py src/google/adk/sessions/migration/migration_runner.py`
  - `Success: no issues found in 2 source files`
- `uv run pre-commit run --files src/google/adk/sessions/migration/migrate_from_sqlalchemy_pickle.py src/google/adk/sessions/migration/migration_runner.py src/google/adk/cli/cli_tools_click.py tests/unittests/sessions/migration/test_migration.py`
  - passed
- `git diff --check`
  - passed

Merge google#5866

Change-Id: I2f66069cb301887fbf7147dbe758b60ec2242d80
1 parent dc32c0a commit 9db48ce

5 files changed

Lines changed: 533 additions & 15 deletions

File tree

src/google/adk/cli/cli_tools_click.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2205,15 +2205,33 @@ def migrate():
22052205
default="INFO",
22062206
help="Optional. Set the logging level",
22072207
)
2208+
@click.option( # type: ignore[untyped-decorator]
2209+
"--allow-unsafe-unpickling",
2210+
"--allow_unsafe_unpickling",
2211+
is_flag=True,
2212+
default=False,
2213+
help=(
2214+
"Optional. Allow unsafe pickle loading for trusted legacy session"
2215+
" databases."
2216+
),
2217+
)
22082218
def cli_migrate_session(
2209-
*, source_db_url: str, dest_db_url: str, log_level: str
2219+
*,
2220+
source_db_url: str,
2221+
dest_db_url: str,
2222+
log_level: str,
2223+
allow_unsafe_unpickling: bool,
22102224
):
22112225
"""Migrates a session database to the latest schema version."""
22122226
logs.setup_adk_logger(getattr(logging, log_level.upper()))
22132227
try:
22142228
from ..sessions.migration import migration_runner
22152229

2216-
migration_runner.upgrade(source_db_url, dest_db_url)
2230+
migration_runner.upgrade(
2231+
source_db_url,
2232+
dest_db_url,
2233+
allow_unsafe_unpickling=allow_unsafe_unpickling,
2234+
)
22172235
click.secho("Migration check and upgrade process finished.", fg="green")
22182236
except Exception as e:
22192237
click.secho(f"Migration failed: {e}", fg="red", err=True)

src/google/adk/sessions/migration/migrate_from_sqlalchemy_pickle.py

Lines changed: 140 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import argparse
1919
from datetime import datetime
2020
from datetime import timezone
21+
import io
2122
import json
2223
import logging
2324
import pickle
@@ -37,6 +38,93 @@
3738

3839
logger = logging.getLogger("google_adk." + __name__)
3940

41+
# Allowlist of (module, qualified name) pairs that the restricted unpickler may
# resolve. Entries are grouped by module and flattened into a set of pairs.
_ALLOWED_PICKLE_GLOBALS: set[tuple[str, str]] = {
    (module_name, global_name)
    for module_name, global_names in {
        # Builtin containers/primitives.
        "builtins": (
            "dict",
            "list",
            "set",
            "tuple",
            "str",
            "bytes",
            "bytearray",
            "int",
            "float",
            "bool",
        ),
        "datetime": (
            "datetime",
            "timedelta",
            "timezone",
        ),
        # Expected pickled payload for v0 session schema events.
        "fastapi.openapi.models": (
            "APIKey",
            "APIKeyIn",
            "HTTPBase",
            "HTTPBearer",
            "OAuth2",
            "OAuthFlow",
            "OAuthFlowAuthorizationCode",
            "OAuthFlowClientCredentials",
            "OAuthFlowImplicit",
            "OAuthFlowPassword",
            "OAuthFlows",
            "OpenIdConnect",
            "SecurityBase",
            "SecurityScheme",
            "SecuritySchemeType",
        ),
        "google.adk.auth.auth_credential": (
            "AuthCredential",
            "AuthCredentialTypes",
            "HttpAuth",
            "HttpCredentials",
            "OAuth2Auth",
            "ServiceAccountCredential",
        ),
        "google.adk.auth.auth_schemes": (
            "CustomAuthScheme",
            "ExtendedOAuth2",
            "OAuthGrantType",
            "OpenIdConnectWithConfig",
        ),
        "google.adk.auth.auth_tool": ("AuthConfig",),
        "google.adk.events.event_actions": (
            "EventActions",
            "EventCompaction",
        ),
        "google.adk.events.ui_widget": ("UiWidget",),
        "google.adk.tools.tool_confirmation": ("ToolConfirmation",),
        "google.genai.types": (
            "Blob",
            "CodeExecutionResult",
            "Content",
            "ExecutableCode",
            "FileData",
            "FunctionCall",
            "FunctionResponse",
            "FunctionResponseBlob",
            "FunctionResponseFileData",
            "FunctionResponsePart",
            "Part",
            "PartMediaResolution",
            "VideoMetadata",
        ),
    }.items()
    for global_name in global_names
}
101+
102+
103+
class _RestrictedUnpickler(pickle.Unpickler):
  """Unpickler that only resolves allowlisted globals.

  The v0 session schema stored `EventActions` as a pickled blob. During
  migration the raw bytes read from the source DB are treated as untrusted
  input, so only the globals listed in `_ALLOWED_PICKLE_GLOBALS` may be
  resolved while reconstructing `EventActions`.
  """

  def find_class(self, module: str, name: str) -> Any:  # noqa: ANN001
    """Resolve `module.name` only when the pair is allowlisted.

    Args:
      module: Module name requested by the pickle stream.
      name: Global name requested by the pickle stream.

    Returns:
      Whatever `pickle.Unpickler.find_class` returns for an allowlisted pair.

    Raises:
      pickle.UnpicklingError: If `(module, name)` is not in
        `_ALLOWED_PICKLE_GLOBALS`.
    """
    requested_global = (module, name)
    # Fail closed: reject before the base class gets a chance to resolve it.
    if requested_global not in _ALLOWED_PICKLE_GLOBALS:
      raise pickle.UnpicklingError(
          f"Blocked global during migration unpickle: {module}.{name}"
      )
    return super().find_class(module, name)
118+
119+
120+
def _restricted_pickle_loads(
    data: bytes, *, allow_unsafe_unpickling: bool = False
) -> Any:
  """Deserialize a pickle payload, using the restricted unpickler by default.

  Args:
    data: Raw pickle bytes.
    allow_unsafe_unpickling: When true, skip the restricted unpickler and call
      `pickle.loads` directly.

  Returns:
    The object reconstructed from `data`.
  """
  if not allow_unsafe_unpickling:
    return _RestrictedUnpickler(io.BytesIO(data)).load()
  # Explicit opt-in path: plain pickle applies no allowlist.
  return pickle.loads(data)
127+
40128

41129
def _to_datetime_obj(val: Any) -> datetime | Any:
42130
"""Converts string to datetime if needed."""
@@ -51,15 +139,19 @@ def _to_datetime_obj(val: Any) -> datetime | Any:
51139
return val
52140

53141

54-
def _row_to_event(row: dict) -> Event:
142+
def _row_to_event(
143+
row: dict[str, Any], *, allow_unsafe_unpickling: bool = False
144+
) -> Event:
55145
"""Converts event row (dict) to event object, handling missing columns and deserializing."""
56146

57147
actions_val = row.get("actions")
58148
actions = None
59149
if actions_val is not None:
60150
try:
61151
if isinstance(actions_val, bytes):
62-
actions = pickle.loads(actions_val)
152+
actions = _restricted_pickle_loads(
153+
actions_val, allow_unsafe_unpickling=allow_unsafe_unpickling
154+
)
63155
else: # for spanner - it might return object directly
64156
actions = actions_val
65157
except Exception as e:
@@ -75,17 +167,25 @@ def _row_to_event(row: dict) -> Event:
75167
else:
76168
actions = EventActions()
77169

78-
def _safe_json_load(val):
79-
data = None
170+
def _safe_json_load(val: Any) -> dict[str, Any] | None:
80171
if isinstance(val, str):
81172
try:
82173
data = json.loads(val)
83174
except json.JSONDecodeError:
84175
logger.warning(f"Failed to decode JSON for event {row.get('id')}")
85176
return None
86177
elif isinstance(val, dict):
87-
data = val # for postgres JSONB
88-
return data
178+
return val # for postgres JSONB
179+
else:
180+
return None
181+
182+
if isinstance(data, dict):
183+
return data
184+
logger.warning(
185+
f"Expected JSON object for event {row.get('id')}, got"
186+
f" {type(data).__name__}."
187+
)
188+
return None
89189

90190
content_dict = _safe_json_load(row.get("content"))
91191
grounding_metadata_dict = _safe_json_load(row.get("grounding_metadata"))
@@ -147,23 +247,31 @@ def _safe_json_load(val):
147247
)
148248

149249

150-
def _get_state_dict(state_val: Any) -> dict:
250+
def _get_state_dict(state_val: Any) -> dict[str, Any]:
151251
"""Safely load dict from JSON string or return dict if already dict."""
152252
if isinstance(state_val, dict):
153253
return state_val
154254
if isinstance(state_val, str):
155255
try:
156-
return json.loads(state_val)
256+
data = json.loads(state_val)
157257
except json.JSONDecodeError:
158258
logger.warning(
159259
"Failed to parse state JSON string, defaulting to empty dict."
160260
)
161261
return {}
262+
if isinstance(data, dict):
263+
return data
264+
logger.warning("State JSON was not an object, defaulting to empty dict.")
265+
return {}
162266
return {}
163267

164268

165269
# --- Migration Logic ---
166-
def migrate(source_db_url: str, dest_db_url: str):
270+
def migrate(
271+
source_db_url: str,
272+
dest_db_url: str,
273+
allow_unsafe_unpickling: bool = False,
274+
) -> None:
167275
"""Migrates data from old pickle schema to new JSON schema."""
168276
# Convert async driver URLs to sync URLs for SQLAlchemy's synchronous engine.
169277
# This allows users to provide URLs like 'postgresql+asyncpg://...' and have
@@ -172,6 +280,11 @@ def migrate(source_db_url: str, dest_db_url: str):
172280
dest_sync_url = _schema_check_utils.to_sync_url(dest_db_url)
173281

174282
logger.info(f"Connecting to source database: {source_db_url}")
283+
if allow_unsafe_unpickling:
284+
logger.warning(
285+
"Unsafe pickle migration mode is enabled. Only use this with a trusted"
286+
" source database."
287+
)
175288
try:
176289
source_engine = create_engine(source_sync_url)
177290
SourceSession = sessionmaker(bind=source_engine)
@@ -265,7 +378,10 @@ def migrate(source_db_url: str, dest_db_url: str):
265378
text("SELECT * FROM events")
266379
).mappings():
267380
try:
268-
event_obj = _row_to_event(dict(row))
381+
event_obj = _row_to_event(
382+
dict(row),
383+
allow_unsafe_unpickling=allow_unsafe_unpickling,
384+
)
269385
new_event = v1.StorageEvent(
270386
id=event_obj.id,
271387
app_name=row["app_name"],
@@ -309,9 +425,22 @@ def migrate(source_db_url: str, dest_db_url: str):
309425
required=True,
310426
help="SQLAlchemy URL of destination database",
311427
)
428+
parser.add_argument(
429+
"--allow_unsafe_unpickling",
430+
"--allow-unsafe-unpickling",
431+
action="store_true",
432+
help=(
433+
"Allow legacy pickle payloads to use Python's unsafe pickle loader."
434+
" Only use this with a trusted source database."
435+
),
436+
)
312437
args = parser.parse_args()
313438
try:
314-
migrate(args.source_db_url, args.dest_db_url)
439+
migrate(
440+
args.source_db_url,
441+
args.dest_db_url,
442+
allow_unsafe_unpickling=args.allow_unsafe_unpickling,
443+
)
315444
except Exception as e:
316445
logger.error(f"Migration failed: {e}")
317446
sys.exit(1)

src/google/adk/sessions/migration/migration_runner.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,11 @@
4242
LATEST_VERSION = _schema_check_utils.LATEST_SCHEMA_VERSION
4343

4444

45-
def upgrade(source_db_url: str, dest_db_url: str):
45+
def upgrade(
46+
source_db_url: str,
47+
dest_db_url: str,
48+
allow_unsafe_unpickling: bool = False,
49+
) -> None:
4650
"""Migrates a database from its current version to the latest version.
4751
4852
If the source database schema is older than the latest version, this
@@ -61,6 +65,9 @@ def upgrade(source_db_url: str, dest_db_url: str):
6165
source_db_url: The SQLAlchemy URL of the database to migrate from.
6266
dest_db_url: The SQLAlchemy URL of the database to migrate to. This must be
6367
different from source_db_url.
68+
allow_unsafe_unpickling: If true, use Python's unsafe pickle loader for the
69+
legacy pickle migration step. Only use this with a trusted source
70+
database.
6471
6572
Raises:
6673
RuntimeError: If source_db_url and dest_db_url are the same, or if no
@@ -113,7 +120,14 @@ def upgrade(source_db_url: str, dest_db_url: str):
113120
logger.info(
114121
f"Migrating from {in_url} to {out_url} (schema v{end_version})..."
115122
)
116-
migrate_func(in_url, out_url)
123+
if migrate_func is migrate_from_sqlalchemy_pickle.migrate:
124+
migrate_func(
125+
in_url,
126+
out_url,
127+
allow_unsafe_unpickling=allow_unsafe_unpickling,
128+
)
129+
else:
130+
migrate_func(in_url, out_url)
117131
logger.info("Finished migration step to schema %s.", end_version)
118132
# The output of this step becomes the input for the next step.
119133
in_url = out_url

tests/unittests/cli/utils/test_cli_tools_click.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1193,6 +1193,53 @@ def test_cli_web_passes_service_uris(
11931193
assert called_kwargs.get("memory_service_uri") == "rag://mycorpus"
11941194

11951195

1196+
@pytest.mark.parametrize(
    "flag",
    ["--allow-unsafe-unpickling", "--allow_unsafe_unpickling"],
)
def test_cli_migrate_session_allows_unsafe_unpickling_flag(
    monkeypatch: pytest.MonkeyPatch, flag: str
) -> None:
  """Either flag spelling is forwarded to `upgrade` as `True`."""
  source_url = "sqlite:///source.db"
  dest_url = "sqlite:///dest.db"
  recorded_calls: list[dict[str, Any]] = []

  def fake_upgrade(
      source_db_url: str,
      dest_db_url: str,
      *,
      allow_unsafe_unpickling: bool = False,
  ) -> None:
    # Record the arguments instead of running a migration.
    recorded_calls.append(
        dict(
            source_db_url=source_db_url,
            dest_db_url=dest_db_url,
            allow_unsafe_unpickling=allow_unsafe_unpickling,
        )
    )

  monkeypatch.setattr(
      "google.adk.sessions.migration.migration_runner.upgrade",
      fake_upgrade,
  )

  cli_args = [
      "migrate",
      "session",
      "--source_db_url",
      source_url,
      "--dest_db_url",
      dest_url,
      flag,
  ]
  result = CliRunner().invoke(cli_tools_click.main, cli_args)

  assert result.exit_code == 0, (result.output, repr(result.exception))
  expected_call = {
      "source_db_url": source_url,
      "dest_db_url": dest_url,
      "allow_unsafe_unpickling": True,
  }
  assert recorded_calls == [expected_call]
1241+
1242+
11961243
def test_cli_eval_with_eval_set_file_path(
11971244
mock_load_eval_set_from_file,
11981245
mock_get_root_agent,

0 commit comments

Comments
 (0)