Skip to content

Commit 8391835

Browse files
authored
Merge branch 'main' into feat/add-metadata-parameter
2 parents 74b9b7c + 4b29d15 commit 8391835

File tree

4 files changed

+179
-6
lines changed

4 files changed

+179
-6
lines changed

src/google/adk/sessions/migration/_schema_check_utils.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,28 @@ def get_db_schema_version_from_connection(connection) -> str:
8282
return _get_schema_version_impl(inspector, connection)
8383

8484

85-
def _to_sync_url(db_url: str) -> str:
86-
"""Removes '+driver' from SQLAlchemy URL."""
85+
def to_sync_url(db_url: str) -> str:
86+
"""Removes '+driver' from SQLAlchemy URL.
87+
88+
This is useful when you need to use a synchronous SQLAlchemy engine with
89+
a database URL that specifies an async driver (e.g., postgresql+asyncpg://
90+
or sqlite+aiosqlite://).
91+
92+
Args:
93+
db_url: The database URL, potentially with a driver specification.
94+
95+
Returns:
96+
The database URL with the driver specification removed (e.g.,
97+
'postgresql+asyncpg://host/db' becomes 'postgresql://host/db').
98+
99+
Examples:
100+
>>> to_sync_url('postgresql+asyncpg://localhost/mydb')
101+
'postgresql://localhost/mydb'
102+
>>> to_sync_url('sqlite+aiosqlite:///path/to/db.sqlite')
103+
'sqlite:///path/to/db.sqlite'
104+
>>> to_sync_url('mysql://localhost/mydb') # No driver, returns unchanged
105+
'mysql://localhost/mydb'
106+
"""
87107
if "://" in db_url:
88108
scheme, _, rest = db_url.partition("://")
89109
if "+" in scheme:
@@ -106,7 +126,7 @@ def get_db_schema_version(db_url: str) -> str:
106126
"""
107127
engine = None
108128
try:
109-
engine = create_sync_engine(_to_sync_url(db_url))
129+
engine = create_sync_engine(to_sync_url(db_url))
110130
with engine.connect() as connection:
111131
inspector = inspect(connection)
112132
return _get_schema_version_impl(inspector, connection)

src/google/adk/sessions/migration/migrate_from_sqlalchemy_pickle.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,17 +165,23 @@ def _get_state_dict(state_val: Any) -> dict:
165165
# --- Migration Logic ---
166166
def migrate(source_db_url: str, dest_db_url: str):
167167
"""Migrates data from old pickle schema to new JSON schema."""
168+
# Convert async driver URLs to sync URLs for SQLAlchemy's synchronous engine.
169+
# This allows users to provide URLs like 'postgresql+asyncpg://...' and have
170+
# them automatically converted to 'postgresql://...' for migration.
171+
source_sync_url = _schema_check_utils.to_sync_url(source_db_url)
172+
dest_sync_url = _schema_check_utils.to_sync_url(dest_db_url)
173+
168174
logger.info(f"Connecting to source database: {source_db_url}")
169175
try:
170-
source_engine = create_engine(source_db_url)
176+
source_engine = create_engine(source_sync_url)
171177
SourceSession = sessionmaker(bind=source_engine)
172178
except Exception as e:
173179
logger.error(f"Failed to connect to source database: {e}")
174180
raise RuntimeError(f"Failed to connect to source database: {e}") from e
175181

176182
logger.info(f"Connecting to destination database: {dest_db_url}")
177183
try:
178-
dest_engine = create_engine(dest_db_url)
184+
dest_engine = create_engine(dest_sync_url)
179185
v1.Base.metadata.create_all(dest_engine)
180186
DestSession = sessionmaker(bind=dest_engine)
181187
except Exception as e:

src/google/adk/sessions/migration/migrate_from_sqlalchemy_sqlite.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import sys
2424

2525
from google.adk.sessions import sqlite_session_service as sss
26+
from google.adk.sessions.migration import _schema_check_utils
2627
from google.adk.sessions.schemas import v0 as v0_schema
2728
from sqlalchemy import create_engine
2829
from sqlalchemy.orm import sessionmaker
@@ -32,9 +33,14 @@
3233

3334
def migrate(source_db_url: str, dest_db_path: str):
3435
"""Migrates data from a SQLAlchemy-based SQLite DB to the new schema."""
36+
# Convert async driver URLs to sync URLs for SQLAlchemy's synchronous engine.
37+
# This allows users to provide URLs like 'sqlite+aiosqlite://...' and have
38+
# them automatically converted to 'sqlite://...' for migration.
39+
source_sync_url = _schema_check_utils.to_sync_url(source_db_url)
40+
3541
logger.info(f"Connecting to source database: {source_db_url}")
3642
try:
37-
engine = create_engine(source_db_url)
43+
engine = create_engine(source_sync_url)
3844
v0_schema.Base.metadata.create_all(
3945
engine
4046
) # Ensure tables exist for inspection

tests/unittests/sessions/migration/test_migration.py

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,88 @@
2323
from google.adk.sessions.migration import migrate_from_sqlalchemy_pickle as mfsp
2424
from google.adk.sessions.schemas import v0
2525
from google.adk.sessions.schemas import v1
26+
import pytest
2627
from sqlalchemy import create_engine
2728
from sqlalchemy.orm import sessionmaker
2829

2930

31+
class TestToSyncUrl:
32+
"""Tests for the to_sync_url function."""
33+
34+
@pytest.mark.parametrize(
35+
"input_url,expected_url",
36+
[
37+
# PostgreSQL async drivers
38+
(
39+
"postgresql+asyncpg://localhost/mydb",
40+
"postgresql://localhost/mydb",
41+
),
42+
(
43+
"postgresql+asyncpg://user:pass@localhost:5432/mydb",
44+
"postgresql://user:pass@localhost:5432/mydb",
45+
),
46+
# PostgreSQL sync drivers (should still strip)
47+
(
48+
"postgresql+psycopg2://localhost/mydb",
49+
"postgresql://localhost/mydb",
50+
),
51+
# MySQL async drivers
52+
(
53+
"mysql+aiomysql://localhost/mydb",
54+
"mysql://localhost/mydb",
55+
),
56+
(
57+
"mysql+asyncmy://user:pass@localhost:3306/mydb",
58+
"mysql://user:pass@localhost:3306/mydb",
59+
),
60+
# SQLite async driver
61+
(
62+
"sqlite+aiosqlite:///path/to/db.sqlite",
63+
"sqlite:///path/to/db.sqlite",
64+
),
65+
(
66+
"sqlite+aiosqlite:///:memory:",
67+
"sqlite:///:memory:",
68+
),
69+
# URLs without driver specification (unchanged)
70+
(
71+
"postgresql://localhost/mydb",
72+
"postgresql://localhost/mydb",
73+
),
74+
(
75+
"mysql://localhost/mydb",
76+
"mysql://localhost/mydb",
77+
),
78+
(
79+
"sqlite:///path/to/db.sqlite",
80+
"sqlite:///path/to/db.sqlite",
81+
),
82+
# Edge cases
83+
(
84+
"sqlite:///:memory:",
85+
"sqlite:///:memory:",
86+
),
87+
# Complex URL with query parameters
88+
(
89+
"postgresql+asyncpg://user:pass@host/db?ssl=require",
90+
"postgresql://user:pass@host/db?ssl=require",
91+
),
92+
],
93+
)
94+
def test_to_sync_url(self, input_url, expected_url):
95+
"""Test that async driver specifications are correctly removed."""
96+
assert _schema_check_utils.to_sync_url(input_url) == expected_url
97+
98+
def test_to_sync_url_no_scheme_separator(self):
99+
"""Test that URLs without :// are returned unchanged."""
100+
# This is an invalid URL but the function should handle it gracefully
101+
assert _schema_check_utils.to_sync_url("not-a-url") == "not-a-url"
102+
103+
def test_to_sync_url_empty_string(self):
104+
"""Test that empty string is returned unchanged."""
105+
assert _schema_check_utils.to_sync_url("") == ""
106+
107+
30108
def test_migrate_from_sqlalchemy_pickle(tmp_path):
31109
"""Tests for migrate_from_sqlalchemy_pickle."""
32110
source_db_path = tmp_path / "source_pickle.db"
@@ -104,3 +182,66 @@ def test_migrate_from_sqlalchemy_pickle(tmp_path):
104182
assert event_res.event_data["actions"]["state_delta"] == {"skey": 4}
105183

106184
dest_session.close()
185+
186+
187+
def test_migrate_from_sqlalchemy_pickle_with_async_driver_urls(tmp_path):
188+
"""Tests that migration works with async driver URLs (fixes issue #4176).
189+
190+
Users often provide async driver URLs (e.g., postgresql+asyncpg://) since
191+
that's what ADK requires at runtime. The migration tool should handle these
192+
by automatically converting them to sync URLs.
193+
"""
194+
source_db_path = tmp_path / "source_pickle_async.db"
195+
dest_db_path = tmp_path / "dest_json_async.db"
196+
# Use async driver URLs like users would typically provide
197+
source_db_url = f"sqlite+aiosqlite:///{source_db_path}"
198+
dest_db_url = f"sqlite+aiosqlite:///{dest_db_path}"
199+
200+
# Set up source DB with old pickle schema using sync URL
201+
sync_source_url = f"sqlite:///{source_db_path}"
202+
source_engine = create_engine(sync_source_url)
203+
v0.Base.metadata.create_all(source_engine)
204+
SourceSession = sessionmaker(bind=source_engine)
205+
source_session = SourceSession()
206+
207+
# Populate source data
208+
now = datetime.now(timezone.utc)
209+
app_state = v0.StorageAppState(
210+
app_name="async_app", state={"key": "value"}, update_time=now
211+
)
212+
session = v0.StorageSession(
213+
app_name="async_app",
214+
user_id="async_user",
215+
id="async_session",
216+
state={},
217+
create_time=now,
218+
update_time=now,
219+
)
220+
source_session.add_all([app_state, session])
221+
source_session.commit()
222+
source_session.close()
223+
224+
# This should NOT raise an error about async drivers (the fix for #4176)
225+
mfsp.migrate(source_db_url, dest_db_url)
226+
227+
# Verify destination DB
228+
sync_dest_url = f"sqlite:///{dest_db_path}"
229+
dest_engine = create_engine(sync_dest_url)
230+
DestSession = sessionmaker(bind=dest_engine)
231+
dest_session = DestSession()
232+
233+
metadata = dest_session.query(v1.StorageMetadata).first()
234+
assert metadata is not None
235+
assert metadata.key == _schema_check_utils.SCHEMA_VERSION_KEY
236+
assert metadata.value == _schema_check_utils.SCHEMA_VERSION_1_JSON
237+
238+
app_state_res = dest_session.query(v1.StorageAppState).first()
239+
assert app_state_res is not None
240+
assert app_state_res.app_name == "async_app"
241+
assert app_state_res.state == {"key": "value"}
242+
243+
session_res = dest_session.query(v1.StorageSession).first()
244+
assert session_res is not None
245+
assert session_res.id == "async_session"
246+
247+
dest_session.close()

0 commit comments

Comments
 (0)