Skip to content

Commit 014d922

Browse files
committed
fix(memory): address database memory review feedback
1 parent dbfc998 commit 014d922

3 files changed

Lines changed: 68 additions & 20 deletions

File tree

src/google/adk_community/memory/database_memory_service.py

Lines changed: 41 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,11 @@
3333
from google.adk.memory.memory_entry import MemoryEntry
3434
from google.genai import types
3535
from sqlalchemy import delete
36+
from sqlalchemy import func
3637
from sqlalchemy import select
38+
from sqlalchemy.dialects.mysql import insert as mysql_insert
39+
from sqlalchemy.dialects.postgresql import insert as postgresql_insert
40+
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
3741
from sqlalchemy.engine import make_url
3842
from sqlalchemy.exc import ArgumentError
3943
from sqlalchemy.ext.asyncio import async_sessionmaker
@@ -57,6 +61,7 @@
5761
logger = logging.getLogger('google_adk.' + __name__)
5862

5963
_SQLITE_DIALECT = 'sqlite'
64+
_MYSQL_DIALECTS = frozenset({'mysql', 'mariadb'})
6065

6166

6267
def _format_timestamp(timestamp: float) -> str:
@@ -378,21 +383,43 @@ async def set_scratchpad(
378383
"""
379384
await self._prepare_tables()
380385
async with self._session() as sql:
381-
existing = await sql.get(
382-
StorageScratchpadKV, (app_name, user_id, session_id, key)
383-
)
384-
if existing is not None:
385-
existing.value_json = value
386-
else:
387-
sql.add(
388-
StorageScratchpadKV(
389-
app_name=app_name,
390-
user_id=user_id,
391-
session_id=session_id,
392-
key=key,
393-
value_json=value,
394-
)
386+
values = {
387+
'app_name': app_name,
388+
'user_id': user_id,
389+
'session_id': session_id,
390+
'key': key,
391+
'value_json': value,
392+
}
393+
dialect_name = sql.get_bind().dialect.name
394+
if dialect_name == 'postgresql':
395+
stmt = postgresql_insert(StorageScratchpadKV).values(**values)
396+
stmt = stmt.on_conflict_do_update(
397+
index_elements=['app_name', 'user_id', 'session_id', 'key'],
398+
set_={
399+
'value_json': stmt.excluded.value_json,
400+
'updated_at': func.now(),
401+
},
402+
)
403+
await sql.execute(stmt)
404+
elif dialect_name == _SQLITE_DIALECT:
405+
stmt = sqlite_insert(StorageScratchpadKV).values(**values)
406+
stmt = stmt.on_conflict_do_update(
407+
index_elements=['app_name', 'user_id', 'session_id', 'key'],
408+
set_={
409+
'value_json': stmt.excluded.value_json,
410+
'updated_at': func.now(),
411+
},
395412
)
413+
await sql.execute(stmt)
414+
elif dialect_name in _MYSQL_DIALECTS:
415+
stmt = mysql_insert(StorageScratchpadKV).values(**values)
416+
stmt = stmt.on_duplicate_key_update(
417+
value_json=stmt.inserted.value_json,
418+
updated_at=func.now(),
419+
)
420+
await sql.execute(stmt)
421+
else:
422+
await sql.merge(StorageScratchpadKV(**values))
396423

397424
async def get_scratchpad(
398425
self,

src/google/adk_community/memory/memory_search_backend.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
if TYPE_CHECKING:
3131
from sqlalchemy.ext.asyncio import AsyncSession
3232

33-
_ILIKE_DIALECTS = frozenset({'postgresql', 'mysql', 'mariadb'})
33+
_ILIKE_DIALECTS = frozenset({'postgresql'})
3434

3535

3636
class MemorySearchBackend(ABC):
@@ -68,8 +68,8 @@ class KeywordSearchBackend(MemorySearchBackend):
6868
2. Try an AND predicate (all tokens must appear) — return if found.
6969
3. Fall back to OR (any token matches) if AND yields nothing.
7070
71-
Uses ILIKE on PostgreSQL/MySQL/MariaDB and LIKE on SQLite
72-
(case-insensitive by default collation).
71+
Uses ILIKE on PostgreSQL and LIKE on other dialects
72+
(case-insensitive by default for common SQLite/MySQL/MariaDB collations).
7373
"""
7474

7575
async def search(
@@ -88,9 +88,7 @@ async def search(
8888
tokens = [
8989
cleaned
9090
for raw in query.split()
91-
if raw.strip()
92-
for cleaned in [re.sub(r'[^\w]', '', raw).lower()]
93-
if cleaned
91+
if (cleaned := re.sub(r'[^\w]', '', raw).lower())
9492
]
9593
if not tokens:
9694
return []

tests/unittests/memory/test_database_memory_service.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from __future__ import annotations
1818

19+
import asyncio
1920
from collections.abc import Sequence
2021
import time
2122
from typing import Any
@@ -278,6 +279,28 @@ async def test_scratchpad_kv_overwrite(svc):
278279
assert val == 'new'
279280

280281

282+
@pytest.mark.asyncio
283+
async def test_scratchpad_kv_concurrent_set_same_new_key(tmp_path):
284+
db_path = tmp_path / 'scratchpad.db'
285+
svc = DatabaseMemoryService(f'sqlite+aiosqlite:///{db_path}')
286+
287+
await asyncio.gather(*[
288+
svc.set_scratchpad(
289+
app_name=_APP,
290+
user_id=_USER,
291+
session_id=_SESSION,
292+
key='shared',
293+
value=i,
294+
)
295+
for i in range(10)
296+
])
297+
298+
val = await svc.get_scratchpad(
299+
app_name=_APP, user_id=_USER, session_id=_SESSION, key='shared'
300+
)
301+
assert val in range(10)
302+
303+
281304
@pytest.mark.asyncio
282305
async def test_scratchpad_kv_missing_returns_none(svc):
283306
val = await svc.get_scratchpad(

0 commit comments

Comments
 (0)