Skip to content

Commit 1c5f45f

Browse files
ensure forward idempotence when resuming aborted rollout
1 parent 8a4c0c3 commit 1c5f45f

1 file changed

Lines changed: 50 additions & 2 deletions

File tree

lagent/agents/agent.py

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,23 @@ def __call__(self, *message: AgentMessage, session_id=0, **kwargs) -> AgentMessa
7070
result = hook.before_agent(self, message, session_id)
7171
if result:
7272
message = result
73-
self.update_memory(message, session_id=session_id)
73+
74+
# resume aborted rollout
75+
_message = self._scroll_buffer(message[-1], session_id)
76+
if _message is not None:
77+
if _message.finish_reason != 'abort':
78+
_message = copy.deepcopy(_message)
79+
for hook in self._hooks.values():
80+
result = hook.after_agent(self, _message, session_id)
81+
if result:
82+
_message = result
83+
return _message
84+
message[-1].extra_info['partial_response'] = _message
85+
else:
86+
self.update_memory(message, session_id=session_id)
7487
response_message = self.forward(*message, session_id=session_id, **kwargs)
88+
if _message and _message.finish_reason == 'abort':
89+
message[-1].extra_info.pop('partial_response', None)
7590
if not isinstance(response_message, AgentMessage):
7691
if isinstance(response_message, str):
7792
response_message = AgentMessage(sender=self.name, content=response_message)
@@ -183,6 +198,24 @@ def get_messages(self, session_id=0, keypath: Optional[str] = None) -> List[dict
183198
return self.aggregator.aggregate(self.memory.get(session_id), self.name, self.output_format, self.template)
184199
raise ValueError(f'{self.name} has no aggregator to get messages')
185200

201+
def _scroll_buffer(self, message, session_id, hash_func=lambda m: m.content):
202+
memory = self.memory and self.memory.get(session_id)
203+
if not memory:
204+
return
205+
mem = self.memory.get_memory(session_id)
206+
is_aborted = [m.finish_reason == 'abort' for m in mem]
207+
if not is_aborted.count(True):
208+
return
209+
aborted_msg_idx = is_aborted.index(True)
210+
memory.delete(range(aborted_msg_idx + 1, len(mem)))
211+
enc = hash_func(message)
212+
for i in range(0, aborted_msg_idx):
213+
if mem[i].sender == message.sender and hash_func(mem[i]) == enc:
214+
ret = mem[i + 1]
215+
if i + 1 == aborted_msg_idx:
216+
memory.delete(aborted_msg_idx)
217+
return ret
218+
186219
def __repr__(self):
187220

188221
def _rcsv_repr(agent, n_indent=1):
@@ -208,8 +241,23 @@ async def __call__(self, *message: AgentMessage, session_id=0, **kwargs) -> Agen
208241
result = hook.before_agent(self, message, session_id)
209242
if result:
210243
message = result
211-
self.update_memory(message, session_id=session_id)
244+
245+
# resume aborted rollout
246+
_message = self._scroll_buffer(message[-1], session_id)
247+
if _message is not None:
248+
if _message.finish_reason != 'abort':
249+
_message = copy.deepcopy(_message)
250+
for hook in self._hooks.values():
251+
result = hook.after_agent(self, _message, session_id)
252+
if result:
253+
_message = result
254+
return _message
255+
message[-1].extra_info['partial_response'] = _message
256+
else:
257+
self.update_memory(message, session_id=session_id)
212258
response_message = await self.forward(*message, session_id=session_id, **kwargs)
259+
if _message and _message.finish_reason == 'abort':
260+
message[-1].extra_info.pop('partial_response', None)
213261
if not isinstance(response_message, AgentMessage):
214262
if isinstance(response_message, str):
215263
response_message = AgentMessage(sender=self.name, content=response_message)

0 commit comments

Comments
 (0)