Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions tests/test_run_hooks.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
from collections import defaultdict
from typing import Any, cast

Expand All @@ -16,6 +17,8 @@
from .fake_model import FakeModel
from .test_responses import (
get_function_tool,
get_function_tool_call,
get_handoff_tool_call,
get_text_message,
)

Expand Down Expand Up @@ -318,3 +321,68 @@ async def test_run_hooks_receives_turn_input_streamed():
turn_input = hooks.captured_turn_inputs[0]
assert len(turn_input) == 1
assert turn_input[0]["content"] == "streamed input"


@pytest.mark.asyncio
async def test_run_hooks_count_tool_and_handoff_invocations():
hooks = RunHooksForTests()
model = FakeModel()

agent_1 = Agent(name="test_1", model=model)
agent_2 = Agent(
name="test_2",
model=model,
handoffs=[agent_1],
tools=[get_function_tool("some_function", "result")],
)

model.add_multiple_turn_outputs(
[
[get_function_tool_call("some_function", json.dumps({"a": "b"}))],
[get_text_message("a_message"), get_handoff_tool_call(agent_1)],
[get_text_message("done")],
]
)
await Runner.run(agent_2, input="user_message", hooks=hooks)

assert hooks.events["on_tool_start"] == 1
assert hooks.events["on_tool_end"] == 1
assert hooks.events["on_handoff"] == 1
assert hooks.events["on_agent_start"] == 2
assert hooks.events["on_agent_end"] == 1
assert len(hooks.tool_context_ids) == 1


@pytest.mark.asyncio
async def test_streamed_run_hooks_count_tool_and_handoff_invocations():
hooks = RunHooksForTests()
model = FakeModel()

agent_1 = Agent(name="test_1", model=model)
agent_2 = Agent(
name="test_2",
model=model,
handoffs=[agent_1],
tools=[get_function_tool("some_function", "result")],
)

model.add_multiple_turn_outputs(
[
[
get_function_tool_call("some_function", json.dumps({"a": "b"})),
get_function_tool_call("some_function", json.dumps({"a": "b"})),
],
[get_text_message("a_message"), get_handoff_tool_call(agent_1)],
[get_text_message("done")],
]
)
stream = Runner.run_streamed(agent_2, input="user_message", hooks=hooks)
async for _ in stream.stream_events():
pass

assert hooks.events["on_tool_start"] == 2
assert hooks.events["on_tool_end"] == 2
assert hooks.events["on_handoff"] == 1
assert hooks.events["on_agent_start"] == 2
assert hooks.events["on_agent_end"] == 1
assert len(hooks.tool_context_ids) == 2