Skip to content

Commit 8c6c00a

Browse files
committed
fix: update tool result assertions to reflect dynamic threshold values
1 parent 0ce5fde commit 8c6c00a

File tree

1 file changed

+41
-16
lines changed

1 file changed

+41
-16
lines changed

tests/test_tool_loop_agent_runner.py

Lines changed: 41 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -568,7 +568,9 @@ def fake_save_image(
568568
async def test_same_tool_consecutive_results_include_escalating_guidance(
569569
runner, mock_tool_executor, mock_hooks
570570
):
571-
provider = SequentialToolProvider(["test_tool"] * 5)
571+
runner_cls = type(runner)
572+
total_calls = runner_cls.REPEATED_TOOL_NOTICE_L3_THRESHOLD
573+
provider = SequentialToolProvider(["test_tool"] * total_calls)
572574
tool = FunctionTool(
573575
name="test_tool",
574576
description="测试工具",
@@ -590,16 +592,15 @@ async def test_same_tool_consecutive_results_include_escalating_guidance(
590592
streaming=False,
591593
)
592594

593-
async for _ in runner.step_until_done(6):
595+
async for _ in runner.step_until_done(total_calls + 1):
594596
pass
595597

596598
tool_messages = [
597599
m for m in runner.run_context.messages if getattr(m, "role", None) == "tool"
598600
]
599-
assert len(tool_messages) == 5
601+
assert len(tool_messages) == total_calls
600602

601603
tool_contents = [str(message.content) for message in tool_messages]
602-
runner_cls = type(runner)
603604
level_1_notice = runner_cls.REPEATED_TOOL_NOTICE_L1_TEMPLATE.format(
604605
tool_name="test_tool",
605606
streak=runner_cls.REPEATED_TOOL_NOTICE_L1_THRESHOLD,
@@ -613,19 +614,33 @@ async def test_same_tool_consecutive_results_include_escalating_guidance(
613614
streak=runner_cls.REPEATED_TOOL_NOTICE_L3_THRESHOLD,
614615
)
615616

616-
assert level_1_notice not in tool_contents[0]
617-
assert level_2_notice not in tool_contents[0]
618-
assert level_1_notice in tool_contents[1]
619-
assert level_2_notice in tool_contents[2]
620-
assert level_3_notice in tool_contents[4]
617+
for streak, content in enumerate(tool_contents, start=1):
618+
if streak < runner_cls.REPEATED_TOOL_NOTICE_L1_THRESHOLD:
619+
assert level_1_notice not in content
620+
assert level_2_notice not in content
621+
assert level_3_notice not in content
622+
elif streak < runner_cls.REPEATED_TOOL_NOTICE_L2_THRESHOLD:
623+
assert level_1_notice in content
624+
assert level_2_notice not in content
625+
assert level_3_notice not in content
626+
elif streak < runner_cls.REPEATED_TOOL_NOTICE_L3_THRESHOLD:
627+
assert level_1_notice not in content
628+
assert level_2_notice in content
629+
assert level_3_notice not in content
630+
else:
631+
assert level_1_notice not in content
632+
assert level_2_notice not in content
633+
assert level_3_notice in content
621634

622635

623636
@pytest.mark.asyncio
624637
async def test_same_tool_streak_resets_after_switching_tools(
625638
runner, mock_tool_executor, mock_hooks
626639
):
640+
runner_cls = type(runner)
641+
repeated_after_reset = runner_cls.REPEATED_TOOL_NOTICE_L1_THRESHOLD
627642
provider = SequentialToolProvider(
628-
["test_tool", "other_tool", "test_tool", "test_tool"]
643+
["test_tool", "other_tool", *(["test_tool"] * repeated_after_reset)]
629644
)
630645
tool_a = FunctionTool(
631646
name="test_tool",
@@ -654,16 +669,15 @@ async def test_same_tool_streak_resets_after_switching_tools(
654669
streaming=False,
655670
)
656671

657-
async for _ in runner.step_until_done(5):
672+
async for _ in runner.step_until_done(repeated_after_reset + 3):
658673
pass
659674

660675
tool_messages = [
661676
m for m in runner.run_context.messages if getattr(m, "role", None) == "tool"
662677
]
663-
assert len(tool_messages) == 4
678+
assert len(tool_messages) == repeated_after_reset + 2
664679

665680
tool_contents = [str(message.content) for message in tool_messages]
666-
runner_cls = type(runner)
667681
level_1_notice = runner_cls.REPEATED_TOOL_NOTICE_L1_TEMPLATE.format(
668682
tool_name="test_tool",
669683
streak=runner_cls.REPEATED_TOOL_NOTICE_L1_THRESHOLD,
@@ -675,9 +689,20 @@ async def test_same_tool_streak_resets_after_switching_tools(
675689

676690
assert level_1_notice not in tool_contents[0]
677691
assert level_1_notice not in tool_contents[1]
678-
assert level_1_notice not in tool_contents[2]
679-
assert level_2_notice not in tool_contents[2]
680-
assert level_1_notice in tool_contents[3]
692+
assert level_2_notice not in tool_contents[0]
693+
assert level_2_notice not in tool_contents[1]
694+
695+
repeated_contents = tool_contents[2:]
696+
for streak_after_reset, content in enumerate(repeated_contents, start=1):
697+
if streak_after_reset < runner_cls.REPEATED_TOOL_NOTICE_L1_THRESHOLD:
698+
assert level_1_notice not in content
699+
assert level_2_notice not in content
700+
elif streak_after_reset < runner_cls.REPEATED_TOOL_NOTICE_L2_THRESHOLD:
701+
assert level_1_notice in content
702+
assert level_2_notice not in content
703+
else:
704+
assert level_1_notice not in content
705+
assert level_2_notice in content
681706

682707

683708
@pytest.mark.asyncio

0 commit comments

Comments
 (0)