@@ -193,6 +193,32 @@ async def text_chat(self, **kwargs) -> LLMResponse:
193193 )
194194
195195
196+ class SequentialToolProvider (MockProvider ):
197+ def __init__ (self , tool_sequence : list [str ]):
198+ super ().__init__ ()
199+ self .tool_sequence = tool_sequence
200+
201+ async def text_chat (self , ** kwargs ) -> LLMResponse :
202+ self .call_count += 1
203+ func_tool = kwargs .get ("func_tool" )
204+ if func_tool is None or self .call_count > len (self .tool_sequence ):
205+ return LLMResponse (
206+ role = "assistant" ,
207+ completion_text = "这是我的最终回答" ,
208+ usage = TokenUsage (input_other = 10 , output = 5 ),
209+ )
210+
211+ tool_name = self .tool_sequence [self .call_count - 1 ]
212+ return LLMResponse (
213+ role = "assistant" ,
214+ completion_text = "" ,
215+ tools_call_name = [tool_name ],
216+ tools_call_args = [{"query" : f"step-{ self .call_count } " }],
217+ tools_call_ids = [f"call_{ self .call_count } " ],
218+ usage = TokenUsage (input_other = 10 , output = 5 ),
219+ )
220+
221+
196222class MockHandoffProvider (MockToolCallProvider ):
197223 def __init__ (self , handoff_tool_name : str ):
198224 super ().__init__ (handoff_tool_name , {"input" : "delegate this task" })
@@ -538,6 +564,101 @@ def fake_save_image(
538564 ]
539565
540566
567+ @pytest .mark .asyncio
568+ async def test_same_tool_consecutive_results_include_escalating_guidance (
569+ runner , mock_tool_executor , mock_hooks
570+ ):
571+ provider = SequentialToolProvider (["test_tool" ] * 5 )
572+ tool = FunctionTool (
573+ name = "test_tool" ,
574+ description = "测试工具" ,
575+ parameters = {"type" : "object" , "properties" : {"query" : {"type" : "string" }}},
576+ handler = AsyncMock (),
577+ )
578+ request = ProviderRequest (
579+ prompt = "请连续执行工具" ,
580+ func_tool = ToolSet (tools = [tool ]),
581+ contexts = [],
582+ )
583+
584+ await runner .reset (
585+ provider = provider ,
586+ request = request ,
587+ run_context = ContextWrapper (context = None ),
588+ tool_executor = mock_tool_executor ,
589+ agent_hooks = mock_hooks ,
590+ streaming = False ,
591+ )
592+
593+ async for _ in runner .step_until_done (6 ):
594+ pass
595+
596+ tool_messages = [
597+ m for m in runner .run_context .messages if getattr (m , "role" , None ) == "tool"
598+ ]
599+ assert len (tool_messages ) == 5
600+
601+ tool_contents = [str (message .content ) for message in tool_messages ]
602+ assert "same tool" not in tool_contents [0 ]
603+ assert "By the way" in tool_contents [1 ]
604+ assert "2 times consecutively" in tool_contents [1 ]
605+ assert "Important" in tool_contents [2 ]
606+ assert "3 times consecutively" in tool_contents [2 ]
607+ assert "Important" in tool_contents [4 ]
608+ assert "5 times consecutively" in tool_contents [4 ]
609+ assert "very high" in tool_contents [4 ]
610+
611+
612+ @pytest .mark .asyncio
613+ async def test_same_tool_streak_resets_after_switching_tools (
614+ runner , mock_tool_executor , mock_hooks
615+ ):
616+ provider = SequentialToolProvider (
617+ ["test_tool" , "other_tool" , "test_tool" , "test_tool" ]
618+ )
619+ tool_a = FunctionTool (
620+ name = "test_tool" ,
621+ description = "测试工具 A" ,
622+ parameters = {"type" : "object" , "properties" : {"query" : {"type" : "string" }}},
623+ handler = AsyncMock (),
624+ )
625+ tool_b = FunctionTool (
626+ name = "other_tool" ,
627+ description = "测试工具 B" ,
628+ parameters = {"type" : "object" , "properties" : {"query" : {"type" : "string" }}},
629+ handler = AsyncMock (),
630+ )
631+ request = ProviderRequest (
632+ prompt = "切换工具后再重复" ,
633+ func_tool = ToolSet (tools = [tool_a , tool_b ]),
634+ contexts = [],
635+ )
636+
637+ await runner .reset (
638+ provider = provider ,
639+ request = request ,
640+ run_context = ContextWrapper (context = None ),
641+ tool_executor = mock_tool_executor ,
642+ agent_hooks = mock_hooks ,
643+ streaming = False ,
644+ )
645+
646+ async for _ in runner .step_until_done (5 ):
647+ pass
648+
649+ tool_messages = [
650+ m for m in runner .run_context .messages if getattr (m , "role" , None ) == "tool"
651+ ]
652+ assert len (tool_messages ) == 4
653+
654+ tool_contents = [str (message .content ) for message in tool_messages ]
655+ assert "same tool" not in tool_contents [0 ]
656+ assert "same tool" not in tool_contents [1 ]
657+ assert "same tool" not in tool_contents [2 ]
658+ assert "By the way" in tool_contents [3 ]
659+ assert "`test_tool` 2 times consecutively" in tool_contents [3 ]
660+
661+
541662@pytest .mark .asyncio
542663async def test_fallback_provider_used_when_primary_raises (
543664 runner , provider_request , mock_tool_executor , mock_hooks
0 commit comments