Skip to content

Commit 24342e9

Browse files
Jacksunweicopybara-github
authored andcommitted
chore: Remove temp state deltas before appending an event
PiperOrigin-RevId: 816902208
1 parent cbe60c4 commit 24342e9

File tree

3 files changed

+56
-2
lines changed

3 files changed

+56
-2
lines changed

src/google/adk/sessions/base_session_service.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from __future__ import annotations
16+
1517
import abc
1618
from typing import Any
1719
from typing import Optional
@@ -95,11 +97,24 @@ async def append_event(self, session: Session, event: Event) -> Event:
9597
"""Appends an event to a session object."""
9698
if event.partial:
9799
return event
98-
self.__update_session_state(session, event)
100+
event = self._trim_temp_delta_state(event)
101+
self._update_session_state(session, event)
99102
session.events.append(event)
100103
return event
101104

102-
def __update_session_state(self, session: Session, event: Event) -> None:
105+
def _trim_temp_delta_state(self, event: Event) -> Event:
106+
"""Removes temporary state delta keys from the event."""
107+
if not event.actions or not event.actions.state_delta:
108+
return event
109+
110+
event.actions.state_delta = {
111+
key: value
112+
for key, value in event.actions.state_delta.items()
113+
if not key.startswith(State.TEMP_PREFIX)
114+
}
115+
return event
116+
117+
def _update_session_state(self, session: Session, event: Event) -> None:
103118
"""Updates the session state based on the event."""
104119
if not event.actions or not event.actions.state_delta:
105120
return

src/google/adk/sessions/database_session_service.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -599,6 +599,9 @@ async def append_event(self, session: Session, event: Event) -> Event:
599599
if event.partial:
600600
return event
601601

602+
# Trim temp state before persisting
603+
event = self._trim_temp_delta_state(event)
604+
602605
# 1. Check if timestamp is stale
603606
# 2. Update session attributes based on event config
604607
# 3. Store event to table

tests/unittests/sessions/test_session_service.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -441,3 +441,39 @@ async def test_append_event_with_fields(service_type):
441441
retrieved_event = retrieved_session.events[0]
442442

443443
assert retrieved_event == event
444+
445+
446+
@pytest.mark.asyncio
447+
@pytest.mark.parametrize(
448+
'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
449+
)
450+
async def test_append_event_should_trim_temp_delta_state(service_type):
451+
session_service = get_session_service(service_type)
452+
app_name = 'my_app'
453+
user_id = 'user'
454+
455+
session = await session_service.create_session(
456+
app_name=app_name, user_id=user_id
457+
)
458+
459+
event = Event(
460+
invocation_id='invocation',
461+
author='user',
462+
content=types.Content(role='user', parts=[types.Part(text='text')]),
463+
actions=EventActions(
464+
state_delta={
465+
'app:key': 'app_value',
466+
'temp:key': 'temp_value',
467+
}
468+
),
469+
)
470+
471+
await session_service.append_event(session, event)
472+
473+
updated_session = await session_service.get_session(
474+
app_name=app_name, user_id=user_id, session_id=session.id
475+
)
476+
477+
last_event = updated_session.events[-1]
478+
assert 'temp:key' not in last_event.actions.state_delta
479+
assert last_event.actions.state_delta['app:key'] == 'app_value'

0 commit comments

Comments
 (0)