Skip to content

Commit d12edb4

Browse files
committed
feat: add on_tool_progress hook for mid-execution tool progress updates
1 parent eca794c commit d12edb4

5 files changed

Lines changed: 410 additions & 0 deletions

File tree

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
"""Example: tool progress via on_tool_progress hooks.
2+
3+
Demonstrates how tools can emit intermediate progress updates using
4+
await ctx.send_progress(data), consumed via RunHooks.on_tool_progress.
5+
"""
6+
7+
import asyncio
8+
9+
from agents import Agent, RunHooks, Runner, function_tool
10+
from agents.tool import Tool
11+
from agents.tool_context import ToolContext
12+
13+
14+
@function_tool
15+
async def analyze_data(ctx: ToolContext, query: str) -> str:
16+
"""Simulate a long-running data analysis task with progress updates."""
17+
await ctx.send_progress({"status": "starting", "query": query})
18+
await asyncio.sleep(1)
19+
20+
await ctx.send_progress({"status": "fetching_data", "progress": 0.25})
21+
await asyncio.sleep(1)
22+
23+
await ctx.send_progress({"status": "processing", "progress": 0.5})
24+
await asyncio.sleep(1)
25+
26+
await ctx.send_progress({"status": "finalizing", "progress": 1.0})
27+
await asyncio.sleep(0.5)
28+
29+
return f"Analysis complete for '{query}': found 42 results with 95% confidence."
30+
31+
32+
@function_tool
33+
async def quick_lookup(ctx: ToolContext, term: str) -> str:
34+
"""A faster tool that also emits progress."""
35+
await ctx.send_progress({"status": "searching", "term": term})
36+
await asyncio.sleep(0.5)
37+
return f"Found definition for '{term}': a common search term."
38+
39+
40+
class ProgressHooks(RunHooks):
41+
async def on_tool_progress(self, ctx, agent, tool: Tool, data):
42+
print(f" [progress] {tool.name}: {data}")
43+
44+
45+
async def main():
46+
agent = Agent(
47+
name="Analyst",
48+
instructions=(
49+
"You are a data analyst. Use the analyze_data tool for complex queries "
50+
"and quick_lookup for simple lookups. Always use the tools when asked."
51+
),
52+
tools=[analyze_data, quick_lookup],
53+
)
54+
55+
hooks = ProgressHooks()
56+
57+
print("Interactive tool progress example (hooks-based).")
58+
print("Type a message to chat, or 'quit' to exit.\n")
59+
60+
while True:
61+
user_input = input("You: ").strip()
62+
if not user_input or user_input.lower() == "quit":
63+
print("Goodbye!")
64+
break
65+
66+
result = Runner.run_streamed(agent, input=user_input, hooks=hooks)
67+
async for event in result.stream_events():
68+
if event.type == "raw_response_event":
69+
data = event.data
70+
if getattr(data, "type", None) == "response.output_text.delta":
71+
print(data.delta, end="", flush=True)
72+
elif event.type == "agent_updated_stream_event":
73+
print(f"Agent: {event.new_agent.name}")
74+
elif event.type == "run_item_stream_event":
75+
if event.item.type == "tool_call_item":
76+
print(f"\n-- Tool called: {getattr(event.item.raw_item, 'name', '?')}")
77+
elif event.item.type == "tool_call_output_item":
78+
print(f"\n-- Tool output: {event.item.output}")
79+
elif event.item.type == "message_output_item":
80+
print() # newline after streamed tokens
81+
82+
print()
83+
84+
85+
if __name__ == "__main__":
86+
asyncio.run(main())

src/agents/lifecycle.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,21 @@ async def on_tool_end(
9898
"""
9999
pass
100100

101+
async def on_tool_progress(
102+
self,
103+
context: RunContextWrapper[TContext],
104+
agent: TAgent,
105+
tool: Tool,
106+
data: Any,
107+
) -> None:
108+
"""Called when a tool emits a progress update via ``send_progress()``.
109+
110+
Unlike ``on_tool_start``/``on_tool_end`` which fire at lifecycle boundaries,
111+
this fires from inside the tool body at arbitrary points. For function-tool
112+
invocations, ``context`` is typically a ``ToolContext``.
113+
"""
114+
pass
115+
101116

102117
class AgentHooksBase(Generic[TContext, TAgent]):
103118
"""A class that receives callbacks on various lifecycle events for a specific agent. You can
@@ -172,6 +187,21 @@ async def on_tool_end(
172187
"""
173188
pass
174189

190+
async def on_tool_progress(
191+
self,
192+
context: RunContextWrapper[TContext],
193+
agent: TAgent,
194+
tool: Tool,
195+
data: Any,
196+
) -> None:
197+
"""Called when a tool emits a progress update via ``send_progress()``.
198+
199+
Unlike ``on_tool_start``/``on_tool_end`` which fire at lifecycle boundaries,
200+
this fires from inside the tool body at arbitrary points. For function-tool
201+
invocations, ``context`` is typically a ``ToolContext``.
202+
"""
203+
pass
204+
175205
async def on_llm_start(
176206
self,
177207
context: RunContextWrapper[TContext],

src/agents/run_internal/tool_execution.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1582,6 +1582,21 @@ async def _run_single_tool(
15821582
run_config=self.config,
15831583
)
15841584
agent_hooks = self.public_agent.hooks
1585+
1586+
async def _send_progress(data: Any) -> None:
1587+
await asyncio.gather(
1588+
self.hooks.on_tool_progress(tool_context, self.public_agent, func_tool, data),
1589+
(
1590+
agent_hooks.on_tool_progress(
1591+
tool_context, self.public_agent, func_tool, data
1592+
)
1593+
if agent_hooks
1594+
else _coro.noop_coroutine()
1595+
),
1596+
)
1597+
1598+
tool_context.set_progress_fn(_send_progress)
1599+
15851600
if self.config.trace_include_sensitive_data:
15861601
span_fn.span_data.input = tool_call.arguments
15871602

src/agents/tool_context.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
from collections.abc import Awaitable, Callable
34
from dataclasses import dataclass, field, fields
45
from typing import TYPE_CHECKING, Any, cast
56

@@ -104,11 +105,25 @@ def __init__(
104105
self.agent = agent
105106
self.run_config = run_config
106107

108+
_progress_fn: Callable[[Any], Awaitable[None]] | None = None
109+
107110
@property
108111
def qualified_tool_name(self) -> str:
109112
"""Return the tool name qualified by namespace when available."""
110113
return tool_trace_name(self.tool_name, self.tool_namespace) or self.tool_name
111114

115+
async def send_progress(self, data: Any) -> None:
116+
"""Emit a progress update, firing ``on_tool_progress`` hooks.
117+
118+
No-op if no progress handler has been set by the framework.
119+
"""
120+
if self._progress_fn is not None:
121+
await self._progress_fn(data)
122+
123+
def set_progress_fn(self, fn: Callable[[Any], Awaitable[None]]) -> None:
124+
"""Set the progress handler. Called by the framework during tool invocation."""
125+
self._progress_fn = fn
126+
112127
@classmethod
113128
def from_agent_context(
114129
cls,

0 commit comments

Comments
 (0)