11from typing import Callable , Dict , List , Optional , Union
22
33from lagent .schema import AgentMessage
4+ from lagent .utils import load_class_from_string
45
56
67class Memory :
78
8- _item_cls = AgentMessage
9-
109 def __init__ (self , recent_n = None ) -> None :
1110 self .memory : List [AgentMessage ] = []
1211 self .recent_n = recent_n
@@ -25,39 +24,43 @@ def get_memory(
2524 memory = [m for i , m in enumerate (memory ) if filter_func (i , m )]
2625 return memory
2726
28- def add (self , memories : Union [List [Dict ], Dict , None ]) -> None :
27+ def add (self , memories : Union [List [AgentMessage | str ], AgentMessage , str ]) -> None :
2928 for memory in memories if isinstance (memories , (list , tuple )) else [memories ]:
3029 if isinstance (memory , str ):
31- memory = self . _item_cls (sender = 'user' , content = memory )
30+ memory = AgentMessage (sender = 'user' , content = memory )
3231 if isinstance (memory , AgentMessage ):
33- if not isinstance (memory , self ._item_cls ):
34- memory = self ._item_cls .model_validate (memory , from_attributes = True )
3532 self .memory .append (memory )
3633
37- def delete (self , index : Union [List , int ]) -> None :
34+ def delete (self , index : Union [List [ int ] , int ]) -> None :
3835 if isinstance (index , int ):
3936 del self .memory [index ]
4037 else :
4138 for i in index :
4239 del self .memory [i ]
4340
44- def load (
45- self ,
46- memories : Union [str , Dict , List ],
47- overwrite : bool = True ,
48- ) -> None :
41+ def load (self , memories : Union [str , dict , List ], overwrite : bool = True ) -> None :
4942 if overwrite :
5043 self .memory = []
5144 if isinstance (memories , dict ):
52- self .memory .append (self ._item_cls .model_validate (memories ))
45+ memories = memories .copy ()
46+ _cls = (
47+ load_class_from_string (memories .pop ('__model_spec__' ))
48+ if '__model_spec__' in memories
49+ else AgentMessage
50+ )
51+ self .memory .append (_cls .model_validate (memories ))
5352 elif isinstance (memories , list ):
5453 for m in memories :
55- self .memory .append (self ._item_cls .model_validate (m ))
54+ m = m .copy ()
55+ _cls = load_class_from_string (m .pop ('__model_spec__' )) if '__model_spec__' in m else AgentMessage
56+ self .memory .append (_cls .model_validate (m ))
5657 else :
5758 raise TypeError (f'{ type (memories )} is not supported' )
5859
5960 def save (self ) -> List [dict ]:
6061 memory = []
6162 for m in self .memory :
62- memory .append (m .model_dump ())
63+ m_dumped = m .model_dump ()
64+ m_dumped ['__model_spec__' ] = f'{ m .__module__ } .{ m .__class__ .__name__ } '
65+ memory .append (m_dumped )
6366 return memory
0 commit comments