Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion src/google/adk/sessions/database_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,9 +515,22 @@ async def list_sessions(
.filter(StorageSession.user_id == user_id)
.all()
)

# Fetch states from storage
storage_app_state = sql_session.get(StorageAppState, (app_name))
storage_user_state = sql_session.get(
StorageUserState, (app_name, user_id)
)

app_state = storage_app_state.state if storage_app_state else {}
user_state = storage_user_state.state if storage_user_state else {}

sessions = []
for storage_session in results:
sessions.append(storage_session.to_session())
session_state = storage_session.state
merged_state = _merge_state(app_state, user_state, session_state)

sessions.append(storage_session.to_session(state=merged_state))
return ListSessionsResponse(sessions=sessions)

@override
Expand Down
2 changes: 1 addition & 1 deletion src/google/adk/sessions/in_memory_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ def _list_sessions_impl(
for session in self.sessions[app_name][user_id].values():
copied_session = copy.deepcopy(session)
copied_session.events = []
copied_session.state = {}
copied_session = self._merge_state(app_name, user_id, copied_session)
sessions_without_events.append(copied_session)
return ListSessionsResponse(sessions=sessions_without_events)

Expand Down
14 changes: 9 additions & 5 deletions src/google/adk/sessions/vertex_ai_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,24 +280,28 @@ async def list_sessions(
parsed_user_id = urllib.parse.quote(f'''"{user_id}"''', safe='')
path = path + f'?filter=user_id={parsed_user_id}'

api_response = await api_client.async_request(
list_sessions_api_response = await api_client.async_request(
http_method='GET',
path=path,
request_dict={},
)
api_response = _convert_api_response(api_response)
list_sessions_api_response = _convert_api_response(
list_sessions_api_response
)

# Handles empty response case
if not api_response or api_response.get('httpHeaders', None):
if not list_sessions_api_response or list_sessions_api_response.get(
'httpHeaders', None
):
return ListSessionsResponse()

sessions = []
for api_session in api_response['sessions']:
for api_session in list_sessions_api_response['sessions']:
session = Session(
app_name=app_name,
user_id=user_id,
id=api_session['name'].split('/')[-1],
state={},
state=api_session.get('sessionState', {}),
last_update_time=isoparse(api_session['updateTime']).timestamp(),
)
sessions.append(session)
Expand Down
6 changes: 5 additions & 1 deletion tests/unittests/sessions/test_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,10 @@ async def test_create_and_list_sessions(service_type):
session_ids = ['session' + str(i) for i in range(5)]
for session_id in session_ids:
await session_service.create_session(
app_name=app_name, user_id=user_id, session_id=session_id
app_name=app_name,
user_id=user_id,
session_id=session_id,
state={'key': 'value' + session_id},
)

list_sessions_response = await session_service.list_sessions(
Expand All @@ -115,6 +118,7 @@ async def test_create_and_list_sessions(service_type):
sessions = list_sessions_response.sessions
for i in range(len(sessions)):
assert sessions[i].id == session_ids[i]
assert sessions[i].state == {'key': 'value' + session_ids[i]}


@pytest.mark.asyncio
Expand Down