@@ -568,7 +568,9 @@ def fake_save_image(
568568async def test_same_tool_consecutive_results_include_escalating_guidance (
569569 runner , mock_tool_executor , mock_hooks
570570):
571- provider = SequentialToolProvider (["test_tool" ] * 5 )
571+ runner_cls = type (runner )
572+ total_calls = runner_cls .REPEATED_TOOL_NOTICE_L3_THRESHOLD
573+ provider = SequentialToolProvider (["test_tool" ] * total_calls )
572574 tool = FunctionTool (
573575 name = "test_tool" ,
574576 description = "测试工具" ,
@@ -590,16 +592,15 @@ async def test_same_tool_consecutive_results_include_escalating_guidance(
590592 streaming = False ,
591593 )
592594
593- async for _ in runner .step_until_done (6 ):
595+ async for _ in runner .step_until_done (total_calls + 1 ):
594596 pass
595597
596598 tool_messages = [
597599 m for m in runner .run_context .messages if getattr (m , "role" , None ) == "tool"
598600 ]
599- assert len (tool_messages ) == 5
601+ assert len (tool_messages ) == total_calls
600602
601603 tool_contents = [str (message .content ) for message in tool_messages ]
602- runner_cls = type (runner )
603604 level_1_notice = runner_cls .REPEATED_TOOL_NOTICE_L1_TEMPLATE .format (
604605 tool_name = "test_tool" ,
605606 streak = runner_cls .REPEATED_TOOL_NOTICE_L1_THRESHOLD ,
@@ -613,19 +614,33 @@ async def test_same_tool_consecutive_results_include_escalating_guidance(
613614 streak = runner_cls .REPEATED_TOOL_NOTICE_L3_THRESHOLD ,
614615 )
615616
616- assert level_1_notice not in tool_contents [0 ]
617- assert level_2_notice not in tool_contents [0 ]
618- assert level_1_notice in tool_contents [1 ]
619- assert level_2_notice in tool_contents [2 ]
620- assert level_3_notice in tool_contents [4 ]
617+ for streak , content in enumerate (tool_contents , start = 1 ):
618+ if streak < runner_cls .REPEATED_TOOL_NOTICE_L1_THRESHOLD :
619+ assert level_1_notice not in content
620+ assert level_2_notice not in content
621+ assert level_3_notice not in content
622+ elif streak < runner_cls .REPEATED_TOOL_NOTICE_L2_THRESHOLD :
623+ assert level_1_notice in content
624+ assert level_2_notice not in content
625+ assert level_3_notice not in content
626+ elif streak < runner_cls .REPEATED_TOOL_NOTICE_L3_THRESHOLD :
627+ assert level_1_notice not in content
628+ assert level_2_notice in content
629+ assert level_3_notice not in content
630+ else :
631+ assert level_1_notice not in content
632+ assert level_2_notice not in content
633+ assert level_3_notice in content
621634
622635
623636@pytest .mark .asyncio
624637async def test_same_tool_streak_resets_after_switching_tools (
625638 runner , mock_tool_executor , mock_hooks
626639):
640+ runner_cls = type (runner )
641+ repeated_after_reset = runner_cls .REPEATED_TOOL_NOTICE_L1_THRESHOLD
627642 provider = SequentialToolProvider (
628- ["test_tool" , "other_tool" , "test_tool" , "test_tool" ]
643+ ["test_tool" , "other_tool" , * ([ "test_tool" ] * repeated_after_reset ) ]
629644 )
630645 tool_a = FunctionTool (
631646 name = "test_tool" ,
@@ -654,16 +669,15 @@ async def test_same_tool_streak_resets_after_switching_tools(
654669 streaming = False ,
655670 )
656671
657- async for _ in runner .step_until_done (5 ):
672+ async for _ in runner .step_until_done (repeated_after_reset + 3 ):
658673 pass
659674
660675 tool_messages = [
661676 m for m in runner .run_context .messages if getattr (m , "role" , None ) == "tool"
662677 ]
663- assert len (tool_messages ) == 4
678+ assert len (tool_messages ) == repeated_after_reset + 2
664679
665680 tool_contents = [str (message .content ) for message in tool_messages ]
666- runner_cls = type (runner )
667681 level_1_notice = runner_cls .REPEATED_TOOL_NOTICE_L1_TEMPLATE .format (
668682 tool_name = "test_tool" ,
669683 streak = runner_cls .REPEATED_TOOL_NOTICE_L1_THRESHOLD ,
@@ -675,9 +689,20 @@ async def test_same_tool_streak_resets_after_switching_tools(
675689
676690 assert level_1_notice not in tool_contents [0 ]
677691 assert level_1_notice not in tool_contents [1 ]
678- assert level_1_notice not in tool_contents [2 ]
679- assert level_2_notice not in tool_contents [2 ]
680- assert level_1_notice in tool_contents [3 ]
692+ assert level_2_notice not in tool_contents [0 ]
693+ assert level_2_notice not in tool_contents [1 ]
694+
695+ repeated_contents = tool_contents [2 :]
696+ for streak_after_reset , content in enumerate (repeated_contents , start = 1 ):
697+ if streak_after_reset < runner_cls .REPEATED_TOOL_NOTICE_L1_THRESHOLD :
698+ assert level_1_notice not in content
699+ assert level_2_notice not in content
700+ elif streak_after_reset < runner_cls .REPEATED_TOOL_NOTICE_L2_THRESHOLD :
701+ assert level_1_notice in content
702+ assert level_2_notice not in content
703+ else :
704+ assert level_1_notice not in content
705+ assert level_2_notice in content
681706
682707
683708@pytest .mark .asyncio
0 commit comments