Skip to content

Commit b581dcc

Browse files
feat: support memory profile (#456)
1 parent 9a012b4 commit b581dcc

File tree

3 files changed

+182
-83
lines changed

3 files changed

+182
-83
lines changed

veadk/memory/short_term_memory.py

Lines changed: 117 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
from functools import wraps
16-
from typing import Any, Callable, Literal
16+
from typing import TYPE_CHECKING, Any, Callable, Literal
1717

1818
from google.adk.sessions import (
1919
BaseSessionService,
@@ -34,6 +34,11 @@
3434
)
3535
from veadk.utils.logger import get_logger
3636

37+
if TYPE_CHECKING:
38+
from google.adk.events import Event
39+
40+
from veadk import Agent
41+
3742
logger = get_logger(__name__)
3843

3944

@@ -69,57 +74,6 @@ class ShortTermMemory(BaseModel):
6974
Default to `/tmp/veadk_local_database.db`.
7075
after_load_memory_callback (Callable | None):
7176
A callback to be called after loading memory from the backend. The callback function should accept `Session` as an input.
72-
73-
Examples:
74-
### In-memory simple memory
75-
76-
You can initialize a short term memory with in-memory storage:
77-
78-
```python
79-
from veadk import Agent, Runner
80-
from veadk.memory.short_term_memory import ShortTermMemory
81-
import asyncio
82-
83-
session_id = "veadk_playground_session"
84-
85-
agent = Agent()
86-
short_term_memory = ShortTermMemory(backend="local")
87-
88-
runner = Runner(
89-
agent=agent, short_term_memory=short_term_memory)
90-
91-
# This invocation will be stored in short-term memory
92-
response = asyncio.run(runner.run(
93-
messages="My name is VeADK", session_id=session_id
94-
))
95-
print(response)
96-
97-
# The history invocation can be fetched by model
98-
response = asyncio.run(runner.run(
99-
messages="Do you remember my name?", session_id=session_id # keep the same `session_id`
100-
))
101-
print(response)
102-
```
103-
104-
### Memory with a Database URL
105-
106-
Also you can use a databasae connection URL to initialize a short-term memory:
107-
108-
```python
109-
from veadk.memory.short_term_memory import ShortTermMemory
110-
111-
short_term_memory = ShortTermMemory(db_url="...")
112-
```
113-
114-
### Memory with SQLite
115-
116-
Once you want to start the short term memory with a local SQLite, you can specify the backend to `sqlite`. It will create a local database in `local_database_path`:
117-
118-
```python
119-
from veadk.memory.short_term_memory import ShortTermMemory
120-
121-
short_term_memory = ShortTermMemory(backend="sqlite", local_database_path="")
122-
```
12377
"""
12478

12579
backend: Literal["local", "mysql", "sqlite", "postgresql", "database"] = "local"
@@ -200,37 +154,6 @@ async def create_session(
200154
201155
Returns:
202156
Session | None: The retrieved or newly created `Session` object, or `None` if the session creation failed.
203-
204-
Examples:
205-
Create a new session manually:
206-
207-
```python
208-
import asyncio
209-
210-
from veadk.memory import ShortTermMemory
211-
212-
app_name = "app_name"
213-
user_id = "user_id"
214-
session_id = "session_id"
215-
216-
short_term_memory = ShortTermMemory()
217-
218-
session = asyncio.run(
219-
short_term_memory.create_session(
220-
app_name=app_name, user_id=user_id, session_id=session_id
221-
)
222-
)
223-
224-
print(session)
225-
226-
session = asyncio.run(
227-
short_term_memory.session_service.get_session(
228-
app_name=app_name, user_id=user_id, session_id=session_id
229-
)
230-
)
231-
232-
print(session)
233-
```
234157
"""
235158
if isinstance(self._session_service, DatabaseSessionService):
236159
list_sessions_response = await self._session_service.list_sessions(
@@ -254,3 +177,114 @@ async def create_session(
254177
return await self._session_service.create_session(
255178
app_name=app_name, user_id=user_id, session_id=session_id
256179
)
180+
181+
async def generate_profile(
182+
self,
183+
app_name: str,
184+
user_id: str,
185+
session_id: str,
186+
events: list["Event"],
187+
) -> list[str]:
188+
import json
189+
190+
from veadk import Agent, Runner
191+
from veadk.memory.types import MemoryProfile
192+
from veadk.utils.misc import write_string_to_file
193+
194+
event_text = ""
195+
for event in events:
196+
event_text += f"- Event id: {event.id}\nEvent content: {event.content}\n"
197+
198+
agent = Agent(
199+
name="memory_summarizer",
200+
description="A summarizer that summarizes the memory events.",
201+
instruction="""Summarize the memory events into different groups according to the event content. An event can belong to multiple groups. You must output the summary in JSON format (Each group should have a simple name (only a-z and _ is allowed), and a list of event ids):
202+
[
203+
{
204+
"name": "",
205+
"event_ids": ["Event id here"]
206+
},
207+
{
208+
"name": "",
209+
"event_ids": ["Event id here"]
210+
}
211+
]""",
212+
model_name="deepseek-v3-2-251201",
213+
output_schema=MemoryProfile,
214+
)
215+
runner = Runner(agent=agent)
216+
217+
response = await runner.run(messages="Events are: \n" + event_text)
218+
219+
# profile path: ./profiles/memory/<app_name>/user_id/session_id/profile_name.json
220+
groups = json.loads(response)
221+
group_names = [group["name"] for group in groups]
222+
223+
for group in groups:
224+
group["event_list"] = []
225+
for event_id in group["event_ids"]:
226+
for event in events:
227+
if event.id == event_id:
228+
group["event_list"].append(event.content.model_dump_json())
229+
230+
write_string_to_file(
231+
content=json.dumps(group_names, ensure_ascii=False),
232+
file_path=f"./profiles/memory/{app_name}/{user_id}/{session_id}/profile_list.json",
233+
)
234+
235+
for group in groups:
236+
write_string_to_file(
237+
content=json.dumps(group, ensure_ascii=False),
238+
file_path=f"./profiles/memory/{app_name}/{user_id}/{session_id}/{group['name']}.json",
239+
)
240+
return group_names
241+
242+
async def compact_history_events(
243+
self,
244+
app_name: str,
245+
user_id: str,
246+
session_id: str,
247+
compact_limit: int,
248+
agent: "Agent",
249+
):
250+
# 1. generate profile
251+
# 2. compact history events
252+
# 3. append instruction and corresponding tool
253+
session = await self.session_service.get_session(
254+
app_name=app_name, user_id=user_id, session_id=session_id
255+
)
256+
257+
compact_event_num = 0
258+
compact_counter = 0
259+
for event in session.events:
260+
if event.content.role == "user":
261+
compact_counter += 1
262+
if compact_counter > compact_limit:
263+
break
264+
compact_event_num += 1
265+
266+
events_need_compact = session.events[:compact_event_num] # type: ignore
267+
268+
group_names = await self.generate_profile(
269+
app_name=app_name,
270+
user_id=user_id,
271+
session_id=session_id,
272+
events=events_need_compact,
273+
)
274+
275+
# TODO(yaozheng): directly edit the events are not work as expected,
276+
# need to check the reason later
277+
session.events = session.events[compact_event_num:] # type: ignore
278+
logger.debug(f"Compacted {compact_event_num} events.")
279+
280+
agent.instruction += f"""
281+
The session has been compacted for the first {compact_limit} events. The compacted content are divided into following groups:
282+
283+
{group_names}
284+
285+
You can call `load_history_events` to load the compacted events if you need them according to the user's request.
286+
"""
287+
288+
from veadk.tools.load_history_events import load_history_events
289+
290+
agent.tools.append(load_history_events)

veadk/memory/types.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from pydantic import BaseModel
16+
17+
18+
class MemoryProfile(BaseModel):
19+
name: str
20+
event_ids: list[str]

veadk/tools/load_history_events.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import json
16+
from pathlib import Path
17+
18+
from google.adk.tools.tool_context import ToolContext
19+
20+
21+
def load_profile(profile_path: Path) -> dict:
22+
# read file content
23+
with open(profile_path, "r") as f:
24+
content = f.read()
25+
return json.loads(content)
26+
27+
28+
def load_history_events(group_names: list[str], tool_context: ToolContext) -> dict:
29+
"""Load necessary history events by group names.
30+
31+
Args:
32+
group_names (list[str]): The list of group names to load events for.
33+
"""
34+
app_name = tool_context._invocation_context.app_name
35+
user_id = tool_context._invocation_context.user_id
36+
session_id = tool_context._invocation_context.session.id
37+
38+
events = {}
39+
for group_name in group_names:
40+
profile_path = Path(
41+
f"./profiles/memory/{app_name}/{user_id}/{session_id}/{group_name}.json"
42+
)
43+
profile = load_profile(profile_path)
44+
events[group_name] = profile.get("event_list", [])
45+
return events

0 commit comments

Comments
 (0)