Skip to content

Commit 5219ba5

Browse files
authored
feat: implement follow-up message handling in ToolLoopAgentRunner (#5484)
* feat: implement follow-up message handling in ToolLoopAgentRunner * fix: correct import path for follow-up module in InternalAgentSubStage
1 parent 84994b5 commit 5219ba5

File tree

4 files changed

+606
-212
lines changed

4 files changed

+606
-212
lines changed

astrbot/core/agent/runners/tool_loop_agent_runner.py

Lines changed: 106 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1+
import asyncio
12
import copy
23
import sys
34
import time
45
import traceback
56
import typing as T
6-
from dataclasses import dataclass
7+
from dataclasses import dataclass, field
78

89
from mcp.types import (
910
BlobResourceContents,
@@ -68,6 +69,14 @@ def from_cached_image(cls, image: T.Any) -> "_HandleFunctionToolsResult":
6869
return cls(kind="cached_image", cached_image=image)
6970

7071

72+
@dataclass(slots=True)
73+
class FollowUpTicket:
74+
seq: int
75+
text: str
76+
consumed: bool = False
77+
resolved: asyncio.Event = field(default_factory=asyncio.Event)
78+
79+
7180
class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
7281
@override
7382
async def reset(
@@ -139,6 +148,8 @@ async def reset(
139148
self.run_context = run_context
140149
self._stop_requested = False
141150
self._aborted = False
151+
self._pending_follow_ups: list[FollowUpTicket] = []
152+
self._follow_up_seq = 0
142153

143154
# These two are used for tool schema mode handling
144155
# We now have two modes:
@@ -277,6 +288,55 @@ def _simple_print_message_role(self, tag: str = ""):
277288
roles.append(message.role)
278289
logger.debug(f"{tag} RunCtx.messages -> [{len(roles)}] {','.join(roles)}")
279290

291+
def follow_up(
292+
self,
293+
*,
294+
message_text: str,
295+
) -> FollowUpTicket | None:
296+
"""Queue a follow-up message for the next tool result."""
297+
if self.done():
298+
return None
299+
text = (message_text or "").strip()
300+
if not text:
301+
return None
302+
ticket = FollowUpTicket(seq=self._follow_up_seq, text=text)
303+
self._follow_up_seq += 1
304+
self._pending_follow_ups.append(ticket)
305+
return ticket
306+
307+
def _resolve_unconsumed_follow_ups(self) -> None:
308+
if not self._pending_follow_ups:
309+
return
310+
follow_ups = self._pending_follow_ups
311+
self._pending_follow_ups = []
312+
for ticket in follow_ups:
313+
ticket.resolved.set()
314+
315+
def _consume_follow_up_notice(self) -> str:
316+
if not self._pending_follow_ups:
317+
return ""
318+
follow_ups = self._pending_follow_ups
319+
self._pending_follow_ups = []
320+
for ticket in follow_ups:
321+
ticket.consumed = True
322+
ticket.resolved.set()
323+
follow_up_lines = "\n".join(
324+
f"{idx}. {ticket.text}" for idx, ticket in enumerate(follow_ups, start=1)
325+
)
326+
return (
327+
"\n\n[SYSTEM NOTICE] User sent follow-up messages while tool execution "
328+
"was in progress. Prioritize these follow-up instructions in your next "
329+
"actions. In your very next action, briefly acknowledge to the user "
330+
"that their follow-up message(s) were received before continuing.\n"
331+
f"{follow_up_lines}"
332+
)
333+
334+
def _merge_follow_up_notice(self, content: str) -> str:
335+
notice = self._consume_follow_up_notice()
336+
if not notice:
337+
return content
338+
return f"{content}{notice}"
339+
280340
@override
281341
async def step(self):
282342
"""Process a single step of the agent.
@@ -391,6 +451,7 @@ async def step(self):
391451
type="aborted",
392452
data=AgentResponseData(chain=MessageChain(type="aborted")),
393453
)
454+
self._resolve_unconsumed_follow_ups()
394455
return
395456

396457
# 处理 LLM 响应
@@ -401,6 +462,7 @@ async def step(self):
401462
self.final_llm_resp = llm_resp
402463
self.stats.end_time = time.time()
403464
self._transition_state(AgentState.ERROR)
465+
self._resolve_unconsumed_follow_ups()
404466
yield AgentResponse(
405467
type="err",
406468
data=AgentResponseData(
@@ -439,6 +501,7 @@ async def step(self):
439501
await self.agent_hooks.on_agent_done(self.run_context, llm_resp)
440502
except Exception as e:
441503
logger.error(f"Error in on_agent_done hook: {e}", exc_info=True)
504+
self._resolve_unconsumed_follow_ups()
442505

443506
# 返回 LLM 结果
444507
if llm_resp.result_chain:
@@ -583,6 +646,15 @@ async def _handle_function_tools(
583646
tool_call_result_blocks: list[ToolCallMessageSegment] = []
584647
logger.info(f"Agent 使用工具: {llm_response.tools_call_name}")
585648

649+
def _append_tool_call_result(tool_call_id: str, content: str) -> None:
650+
tool_call_result_blocks.append(
651+
ToolCallMessageSegment(
652+
role="tool",
653+
tool_call_id=tool_call_id,
654+
content=self._merge_follow_up_notice(content),
655+
),
656+
)
657+
586658
# 执行函数调用
587659
for func_tool_name, func_tool_args, func_tool_id in zip(
588660
llm_response.tools_call_name,
@@ -622,12 +694,9 @@ async def _handle_function_tools(
622694

623695
if not func_tool:
624696
logger.warning(f"未找到指定的工具: {func_tool_name},将跳过。")
625-
tool_call_result_blocks.append(
626-
ToolCallMessageSegment(
627-
role="tool",
628-
tool_call_id=func_tool_id,
629-
content=f"error: Tool {func_tool_name} not found.",
630-
),
697+
_append_tool_call_result(
698+
func_tool_id,
699+
f"error: Tool {func_tool_name} not found.",
631700
)
632701
continue
633702

@@ -680,12 +749,9 @@ async def _handle_function_tools(
680749
res = resp
681750
_final_resp = resp
682751
if isinstance(res.content[0], TextContent):
683-
tool_call_result_blocks.append(
684-
ToolCallMessageSegment(
685-
role="tool",
686-
tool_call_id=func_tool_id,
687-
content=res.content[0].text,
688-
),
752+
_append_tool_call_result(
753+
func_tool_id,
754+
res.content[0].text,
689755
)
690756
elif isinstance(res.content[0], ImageContent):
691757
# Cache the image instead of sending directly
@@ -696,15 +762,12 @@ async def _handle_function_tools(
696762
index=0,
697763
mime_type=res.content[0].mimeType or "image/png",
698764
)
699-
tool_call_result_blocks.append(
700-
ToolCallMessageSegment(
701-
role="tool",
702-
tool_call_id=func_tool_id,
703-
content=(
704-
f"Image returned and cached at path='{cached_img.file_path}'. "
705-
f"Review the image below. Use send_message_to_user to send it to the user if satisfied, "
706-
f"with type='image' and path='{cached_img.file_path}'."
707-
),
765+
_append_tool_call_result(
766+
func_tool_id,
767+
(
768+
f"Image returned and cached at path='{cached_img.file_path}'. "
769+
f"Review the image below. Use send_message_to_user to send it to the user if satisfied, "
770+
f"with type='image' and path='{cached_img.file_path}'."
708771
),
709772
)
710773
# Yield image info for LLM visibility (will be handled in step())
@@ -714,12 +777,9 @@ async def _handle_function_tools(
714777
elif isinstance(res.content[0], EmbeddedResource):
715778
resource = res.content[0].resource
716779
if isinstance(resource, TextResourceContents):
717-
tool_call_result_blocks.append(
718-
ToolCallMessageSegment(
719-
role="tool",
720-
tool_call_id=func_tool_id,
721-
content=resource.text,
722-
),
780+
_append_tool_call_result(
781+
func_tool_id,
782+
resource.text,
723783
)
724784
elif (
725785
isinstance(resource, BlobResourceContents)
@@ -734,28 +794,22 @@ async def _handle_function_tools(
734794
index=0,
735795
mime_type=resource.mimeType,
736796
)
737-
tool_call_result_blocks.append(
738-
ToolCallMessageSegment(
739-
role="tool",
740-
tool_call_id=func_tool_id,
741-
content=(
742-
f"Image returned and cached at path='{cached_img.file_path}'. "
743-
f"Review the image below. Use send_message_to_user to send it to the user if satisfied, "
744-
f"with type='image' and path='{cached_img.file_path}'."
745-
),
797+
_append_tool_call_result(
798+
func_tool_id,
799+
(
800+
f"Image returned and cached at path='{cached_img.file_path}'. "
801+
f"Review the image below. Use send_message_to_user to send it to the user if satisfied, "
802+
f"with type='image' and path='{cached_img.file_path}'."
746803
),
747804
)
748805
# Yield image info for LLM visibility
749806
yield _HandleFunctionToolsResult.from_cached_image(
750807
cached_img
751808
)
752809
else:
753-
tool_call_result_blocks.append(
754-
ToolCallMessageSegment(
755-
role="tool",
756-
tool_call_id=func_tool_id,
757-
content="The tool has returned a data type that is not supported.",
758-
),
810+
_append_tool_call_result(
811+
func_tool_id,
812+
"The tool has returned a data type that is not supported.",
759813
)
760814

761815
elif resp is None:
@@ -767,24 +821,18 @@ async def _handle_function_tools(
767821
)
768822
self._transition_state(AgentState.DONE)
769823
self.stats.end_time = time.time()
770-
tool_call_result_blocks.append(
771-
ToolCallMessageSegment(
772-
role="tool",
773-
tool_call_id=func_tool_id,
774-
content="The tool has no return value, or has sent the result directly to the user.",
775-
),
824+
_append_tool_call_result(
825+
func_tool_id,
826+
"The tool has no return value, or has sent the result directly to the user.",
776827
)
777828
else:
778829
# 不应该出现其他类型
779830
logger.warning(
780831
f"Tool 返回了不支持的类型: {type(resp)}。",
781832
)
782-
tool_call_result_blocks.append(
783-
ToolCallMessageSegment(
784-
role="tool",
785-
tool_call_id=func_tool_id,
786-
content="*The tool has returned an unsupported type. Please tell the user to check the definition and implementation of this tool.*",
787-
),
833+
_append_tool_call_result(
834+
func_tool_id,
835+
"*The tool has returned an unsupported type. Please tell the user to check the definition and implementation of this tool.*",
788836
)
789837

790838
try:
@@ -798,12 +846,9 @@ async def _handle_function_tools(
798846
logger.error(f"Error in on_tool_end hook: {e}", exc_info=True)
799847
except Exception as e:
800848
logger.warning(traceback.format_exc())
801-
tool_call_result_blocks.append(
802-
ToolCallMessageSegment(
803-
role="tool",
804-
tool_call_id=func_tool_id,
805-
content=f"error: {e!s}",
806-
),
849+
_append_tool_call_result(
850+
func_tool_id,
851+
f"error: {e!s}",
807852
)
808853

809854
# yield the last tool call result

0 commit comments

Comments
 (0)