Skip to content

Commit 000209e

Browse files
improve memory serialization and deserialization
1 parent 6e4cb10 commit 000209e

1 file changed

Lines changed: 18 additions & 15 deletions

File tree

lagent/memory/base_memory.py

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
from typing import Callable, Dict, List, Optional, Union
22

33
from lagent.schema import AgentMessage
4+
from lagent.utils import load_class_from_string
45

56

67
class Memory:
78

8-
_item_cls = AgentMessage
9-
109
def __init__(self, recent_n=None) -> None:
1110
self.memory: List[AgentMessage] = []
1211
self.recent_n = recent_n
@@ -25,39 +24,43 @@ def get_memory(
2524
memory = [m for i, m in enumerate(memory) if filter_func(i, m)]
2625
return memory
2726

28-
def add(self, memories: Union[List[Dict], Dict, None]) -> None:
27+
def add(self, memories: Union[List[AgentMessage | str], AgentMessage, str]) -> None:
2928
for memory in memories if isinstance(memories, (list, tuple)) else [memories]:
3029
if isinstance(memory, str):
31-
memory = self._item_cls(sender='user', content=memory)
30+
memory = AgentMessage(sender='user', content=memory)
3231
if isinstance(memory, AgentMessage):
33-
if not isinstance(memory, self._item_cls):
34-
memory = self._item_cls.model_validate(memory, from_attributes=True)
3532
self.memory.append(memory)
3633

37-
def delete(self, index: Union[List, int]) -> None:
34+
def delete(self, index: Union[List[int], int]) -> None:
3835
if isinstance(index, int):
3936
del self.memory[index]
4037
else:
4138
for i in index:
4239
del self.memory[i]
4340

44-
def load(
45-
self,
46-
memories: Union[str, Dict, List],
47-
overwrite: bool = True,
48-
) -> None:
41+
def load(self, memories: Union[str, dict, List], overwrite: bool = True) -> None:
4942
if overwrite:
5043
self.memory = []
5144
if isinstance(memories, dict):
52-
self.memory.append(self._item_cls.model_validate(memories))
45+
memories = memories.copy()
46+
_cls = (
47+
load_class_from_string(memories.pop('__model_spec__'))
48+
if '__model_spec__' in memories
49+
else AgentMessage
50+
)
51+
self.memory.append(_cls.model_validate(memories))
5352
elif isinstance(memories, list):
5453
for m in memories:
55-
self.memory.append(self._item_cls.model_validate(m))
54+
m = m.copy()
55+
_cls = load_class_from_string(m.pop('__model_spec__')) if '__model_spec__' in m else AgentMessage
56+
self.memory.append(_cls.model_validate(m))
5657
else:
5758
raise TypeError(f'{type(memories)} is not supported')
5859

5960
def save(self) -> List[dict]:
6061
memory = []
6162
for m in self.memory:
62-
memory.append(m.model_dump())
63+
m_dumped = m.model_dump()
64+
m_dumped['__model_spec__'] = f'{m.__module__}.{m.__class__.__name__}'
65+
memory.append(m_dumped)
6366
return memory

0 commit comments

Comments
 (0)