@@ -70,8 +70,23 @@ def __call__(self, *message: AgentMessage, session_id=0, **kwargs) -> AgentMessa
7070 result = hook .before_agent (self , message , session_id )
7171 if result :
7272 message = result
73- self .update_memory (message , session_id = session_id )
73+
74+ # resume aborted rollout
75+ _message = self ._scroll_buffer (message [- 1 ], session_id )
76+ if _message is not None :
77+ if _message .finish_reason != 'abort' :
78+ _message = copy .deepcopy (_message )
79+ for hook in self ._hooks .values ():
80+ result = hook .after_agent (self , _message , session_id )
81+ if result :
82+ _message = result
83+ return _message
84+ message [- 1 ].extra_info ['partial_response' ] = _message
85+ else :
86+ self .update_memory (message , session_id = session_id )
7487 response_message = self .forward (* message , session_id = session_id , ** kwargs )
88+ if _message and _message .finish_reason == 'abort' :
89+ message [- 1 ].extra_info .pop ('partial_response' , None )
7590 if not isinstance (response_message , AgentMessage ):
7691 if isinstance (response_message , str ):
7792 response_message = AgentMessage (sender = self .name , content = response_message )
@@ -183,6 +198,24 @@ def get_messages(self, session_id=0, keypath: Optional[str] = None) -> List[dict
183198 return self .aggregator .aggregate (self .memory .get (session_id ), self .name , self .output_format , self .template )
184199 raise ValueError (f'{ self .name } has no aggregator to get messages' )
185200
201+ def _scroll_buffer (self , message , session_id , hash_func = lambda m : m .content ):
202+ memory = self .memory and self .memory .get (session_id )
203+ if not memory :
204+ return
205+ mem = self .memory .get_memory (session_id )
206+ is_aborted = [m .finish_reason == 'abort' for m in mem ]
207+ if not is_aborted .count (True ):
208+ return
209+ aborted_msg_idx = is_aborted .index (True )
210+ memory .delete (range (aborted_msg_idx + 1 , len (mem )))
211+ enc = hash_func (message )
212+ for i in range (0 , aborted_msg_idx ):
213+ if mem [i ].sender == message .sender and hash_func (mem [i ]) == enc :
214+ ret = mem [i + 1 ]
215+ if i + 1 == aborted_msg_idx :
216+ memory .delete (aborted_msg_idx )
217+ return ret
218+
186219 def __repr__ (self ):
187220
188221 def _rcsv_repr (agent , n_indent = 1 ):
@@ -208,8 +241,23 @@ async def __call__(self, *message: AgentMessage, session_id=0, **kwargs) -> Agen
208241 result = hook .before_agent (self , message , session_id )
209242 if result :
210243 message = result
211- self .update_memory (message , session_id = session_id )
244+
245+ # resume aborted rollout
246+ _message = self ._scroll_buffer (message [- 1 ], session_id )
247+ if _message is not None :
248+ if _message .finish_reason != 'abort' :
249+ _message = copy .deepcopy (_message )
250+ for hook in self ._hooks .values ():
251+ result = hook .after_agent (self , _message , session_id )
252+ if result :
253+ _message = result
254+ return _message
255+ message [- 1 ].extra_info ['partial_response' ] = _message
256+ else :
257+ self .update_memory (message , session_id = session_id )
212258 response_message = await self .forward (* message , session_id = session_id , ** kwargs )
259+ if _message and _message .finish_reason == 'abort' :
260+ message [- 1 ].extra_info .pop ('partial_response' , None )
213261 if not isinstance (response_message , AgentMessage ):
214262 if isinstance (response_message , str ):
215263 response_message = AgentMessage (sender = self .name , content = response_message )
0 commit comments