1818import argparse
1919from datetime import datetime
2020from datetime import timezone
21+ import io
2122import json
2223import logging
2324import pickle
3738
3839logger = logging .getLogger ("google_adk." + __name__ )
3940
41+ _ALLOWED_PICKLE_GLOBALS = {
42+ # Basic types/containers
43+ ("builtins" , "dict" ),
44+ ("builtins" , "list" ),
45+ ("builtins" , "set" ),
46+ ("builtins" , "tuple" ),
47+ ("builtins" , "str" ),
48+ ("builtins" , "bytes" ),
49+ ("builtins" , "bytearray" ),
50+ ("builtins" , "int" ),
51+ ("builtins" , "float" ),
52+ ("builtins" , "bool" ),
53+ ("datetime" , "datetime" ),
54+ ("datetime" , "timedelta" ),
55+ ("datetime" , "timezone" ),
56+ # Expected pickled payload for v0 session schema events.
57+ ("fastapi.openapi.models" , "APIKey" ),
58+ ("fastapi.openapi.models" , "APIKeyIn" ),
59+ ("fastapi.openapi.models" , "HTTPBase" ),
60+ ("fastapi.openapi.models" , "HTTPBearer" ),
61+ ("fastapi.openapi.models" , "OAuth2" ),
62+ ("fastapi.openapi.models" , "OAuthFlow" ),
63+ ("fastapi.openapi.models" , "OAuthFlowAuthorizationCode" ),
64+ ("fastapi.openapi.models" , "OAuthFlowClientCredentials" ),
65+ ("fastapi.openapi.models" , "OAuthFlowImplicit" ),
66+ ("fastapi.openapi.models" , "OAuthFlowPassword" ),
67+ ("fastapi.openapi.models" , "OAuthFlows" ),
68+ ("fastapi.openapi.models" , "OpenIdConnect" ),
69+ ("fastapi.openapi.models" , "SecurityBase" ),
70+ ("fastapi.openapi.models" , "SecurityScheme" ),
71+ ("fastapi.openapi.models" , "SecuritySchemeType" ),
72+ ("google.adk.auth.auth_credential" , "AuthCredential" ),
73+ ("google.adk.auth.auth_credential" , "AuthCredentialTypes" ),
74+ ("google.adk.auth.auth_credential" , "HttpAuth" ),
75+ ("google.adk.auth.auth_credential" , "HttpCredentials" ),
76+ ("google.adk.auth.auth_credential" , "OAuth2Auth" ),
77+ ("google.adk.auth.auth_credential" , "ServiceAccountCredential" ),
78+ ("google.adk.auth.auth_schemes" , "CustomAuthScheme" ),
79+ ("google.adk.auth.auth_schemes" , "ExtendedOAuth2" ),
80+ ("google.adk.auth.auth_schemes" , "OAuthGrantType" ),
81+ ("google.adk.auth.auth_schemes" , "OpenIdConnectWithConfig" ),
82+ ("google.adk.auth.auth_tool" , "AuthConfig" ),
83+ ("google.adk.events.event_actions" , "EventActions" ),
84+ ("google.adk.events.event_actions" , "EventCompaction" ),
85+ ("google.adk.events.ui_widget" , "UiWidget" ),
86+ ("google.adk.tools.tool_confirmation" , "ToolConfirmation" ),
87+ ("google.genai.types" , "Blob" ),
88+ ("google.genai.types" , "CodeExecutionResult" ),
89+ ("google.genai.types" , "Content" ),
90+ ("google.genai.types" , "ExecutableCode" ),
91+ ("google.genai.types" , "FileData" ),
92+ ("google.genai.types" , "FunctionCall" ),
93+ ("google.genai.types" , "FunctionResponse" ),
94+ ("google.genai.types" , "FunctionResponseBlob" ),
95+ ("google.genai.types" , "FunctionResponseFileData" ),
96+ ("google.genai.types" , "FunctionResponsePart" ),
97+ ("google.genai.types" , "Part" ),
98+ ("google.genai.types" , "PartMediaResolution" ),
99+ ("google.genai.types" , "VideoMetadata" ),
100+ }
101+
102+
class _RestrictedUnpickler(pickle.Unpickler):
  """Unpickler that only resolves allow-listed globals.

  Used while migrating legacy v0 schema rows, where `EventActions` was stored
  as a pickled blob. The bytes read from the source DB are treated as
  untrusted input, so only the globals listed in `_ALLOWED_PICKLE_GLOBALS`
  may be looked up while reconstructing the payload.
  """

  def find_class(self, module: str, name: str) -> Any:  # noqa: ANN001
    """Resolve a pickled global, rejecting anything not on the allow-list.

    Args:
      module: Module name recorded in the pickle stream.
      name: Global name recorded in the pickle stream.

    Returns:
      The object resolved by the base `pickle.Unpickler.find_class`.

    Raises:
      pickle.UnpicklingError: If `(module, name)` is not allow-listed.
    """
    # Reject first so the base-class lookup is never reached for a blocked
    # global.
    if (module, name) not in _ALLOWED_PICKLE_GLOBALS:
      raise pickle.UnpicklingError(
          f"Blocked global during migration unpickle: {module}.{name}"
      )
    return super().find_class(module, name)
118+
119+
120+ def _restricted_pickle_loads (
121+ data : bytes , * , allow_unsafe_unpickling : bool = False
122+ ) -> Any :
123+ """Load a pickle payload using the restricted unpickler by default."""
124+ if allow_unsafe_unpickling :
125+ return pickle .loads (data )
126+ return _RestrictedUnpickler (io .BytesIO (data )).load ()
127+
40128
41129def _to_datetime_obj (val : Any ) -> datetime | Any :
42130 """Converts string to datetime if needed."""
@@ -51,15 +139,19 @@ def _to_datetime_obj(val: Any) -> datetime | Any:
51139 return val
52140
53141
54- def _row_to_event (row : dict ) -> Event :
142+ def _row_to_event (
143+ row : dict [str , Any ], * , allow_unsafe_unpickling : bool = False
144+ ) -> Event :
55145 """Converts event row (dict) to event object, handling missing columns and deserializing."""
56146
57147 actions_val = row .get ("actions" )
58148 actions = None
59149 if actions_val is not None :
60150 try :
61151 if isinstance (actions_val , bytes ):
62- actions = pickle .loads (actions_val )
152+ actions = _restricted_pickle_loads (
153+ actions_val , allow_unsafe_unpickling = allow_unsafe_unpickling
154+ )
63155 else : # for spanner - it might return object directly
64156 actions = actions_val
65157 except Exception as e :
@@ -75,17 +167,25 @@ def _row_to_event(row: dict) -> Event:
75167 else :
76168 actions = EventActions ()
77169
78- def _safe_json_load (val ):
79- data = None
170+ def _safe_json_load (val : Any ) -> dict [str , Any ] | None :
80171 if isinstance (val , str ):
81172 try :
82173 data = json .loads (val )
83174 except json .JSONDecodeError :
84175 logger .warning (f"Failed to decode JSON for event { row .get ('id' )} " )
85176 return None
86177 elif isinstance (val , dict ):
87- data = val # for postgres JSONB
88- return data
178+ return val # for postgres JSONB
179+ else :
180+ return None
181+
182+ if isinstance (data , dict ):
183+ return data
184+ logger .warning (
185+ f"Expected JSON object for event { row .get ('id' )} , got"
186+ f" { type (data ).__name__ } ."
187+ )
188+ return None
89189
90190 content_dict = _safe_json_load (row .get ("content" ))
91191 grounding_metadata_dict = _safe_json_load (row .get ("grounding_metadata" ))
@@ -147,23 +247,31 @@ def _safe_json_load(val):
147247 )
148248
149249
150- def _get_state_dict (state_val : Any ) -> dict :
250+ def _get_state_dict (state_val : Any ) -> dict [ str , Any ] :
151251 """Safely load dict from JSON string or return dict if already dict."""
152252 if isinstance (state_val , dict ):
153253 return state_val
154254 if isinstance (state_val , str ):
155255 try :
156- return json .loads (state_val )
256+ data = json .loads (state_val )
157257 except json .JSONDecodeError :
158258 logger .warning (
159259 "Failed to parse state JSON string, defaulting to empty dict."
160260 )
161261 return {}
262+ if isinstance (data , dict ):
263+ return data
264+ logger .warning ("State JSON was not an object, defaulting to empty dict." )
265+ return {}
162266 return {}
163267
164268
165269# --- Migration Logic ---
166- def migrate (source_db_url : str , dest_db_url : str ):
270+ def migrate (
271+ source_db_url : str ,
272+ dest_db_url : str ,
273+ allow_unsafe_unpickling : bool = False ,
274+ ) -> None :
167275 """Migrates data from old pickle schema to new JSON schema."""
168276 # Convert async driver URLs to sync URLs for SQLAlchemy's synchronous engine.
169277 # This allows users to provide URLs like 'postgresql+asyncpg://...' and have
@@ -172,6 +280,11 @@ def migrate(source_db_url: str, dest_db_url: str):
172280 dest_sync_url = _schema_check_utils .to_sync_url (dest_db_url )
173281
174282 logger .info (f"Connecting to source database: { source_db_url } " )
283+ if allow_unsafe_unpickling :
284+ logger .warning (
285+ "Unsafe pickle migration mode is enabled. Only use this with a trusted"
286+ " source database."
287+ )
175288 try :
176289 source_engine = create_engine (source_sync_url )
177290 SourceSession = sessionmaker (bind = source_engine )
@@ -265,7 +378,10 @@ def migrate(source_db_url: str, dest_db_url: str):
265378 text ("SELECT * FROM events" )
266379 ).mappings ():
267380 try :
268- event_obj = _row_to_event (dict (row ))
381+ event_obj = _row_to_event (
382+ dict (row ),
383+ allow_unsafe_unpickling = allow_unsafe_unpickling ,
384+ )
269385 new_event = v1 .StorageEvent (
270386 id = event_obj .id ,
271387 app_name = row ["app_name" ],
@@ -309,9 +425,22 @@ def migrate(source_db_url: str, dest_db_url: str):
309425 required = True ,
310426 help = "SQLAlchemy URL of destination database" ,
311427 )
428+ parser .add_argument (
429+ "--allow_unsafe_unpickling" ,
430+ "--allow-unsafe-unpickling" ,
431+ action = "store_true" ,
432+ help = (
433+ "Allow legacy pickle payloads to use Python's unsafe pickle loader."
434+ " Only use this with a trusted source database."
435+ ),
436+ )
312437 args = parser .parse_args ()
313438 try :
314- migrate (args .source_db_url , args .dest_db_url )
439+ migrate (
440+ args .source_db_url ,
441+ args .dest_db_url ,
442+ allow_unsafe_unpickling = args .allow_unsafe_unpickling ,
443+ )
315444 except Exception as e :
316445 logger .error (f"Migration failed: { e } " )
317446 sys .exit (1 )
0 commit comments