Skip to content

Commit 0927155

Browse files
Emaan Khanemaan-c
authored andcommitted
feat: add streaming to direct tool calls
1 parent 50b2c79 commit 0927155

3 files changed

Lines changed: 655 additions & 84 deletions

File tree

src/strands/tools/_caller.py

Lines changed: 220 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
"""Support direct tool calls through agent.
1+
"""Direct tool call support.
2+
3+
This module provides the _DirectToolCall and _ToolCaller classes that enable direct tool invocation through the
4+
agent.tool interface, including synchronous execution and streaming methods.
25
36
Example:
47
```
@@ -7,10 +10,15 @@
710
```
811
"""
912

13+
import asyncio
14+
import contextvars
1015
import json
16+
import logging
17+
import queue
1118
import random
1219
import weakref
13-
from collections.abc import Callable
20+
from collections.abc import AsyncIterator, Iterator
21+
from concurrent.futures import ThreadPoolExecutor
1422
from typing import TYPE_CHECKING, Any
1523

1624
from .._async import run_async
@@ -24,19 +32,34 @@
2432
from ..agent import Agent
2533
from ..experimental.bidi.agent import BidiAgent
2634

35+
logger = logging.getLogger(__name__)
2736

28-
class _ToolCaller:
29-
"""Call tool as a function."""
37+
# Sentinel to signal end of stream
38+
_STREAM_END = object()
3039

31-
def __init__(self, agent: "Agent | BidiAgent") -> None:
32-
"""Initialize instance.
40+
41+
class _DirectToolCall:
42+
"""Callable wrapper for a single tool that provides streaming methods.
43+
44+
This class enables three execution modes for direct tool calls:
45+
1. Synchronous: ``result = agent.tool.my_tool(x=5)``
46+
2. Sync streaming: ``for event in agent.tool.my_tool.stream(x=5)``
47+
3. Async streaming: ``async for event in agent.tool.my_tool.stream_async(x=5)``
48+
49+
Streaming methods do not acquire the invocation lock, do not record to message
50+
history, and do not apply conversation management. They are designed for
51+
observability and real-time progress monitoring.
52+
"""
53+
54+
def __init__(self, agent: "Agent | BidiAgent", tool_name: str) -> None:
55+
"""Initialize direct tool call.
3356
3457
Args:
35-
agent: Agent reference that will accept tool results.
58+
agent: Agent reference that owns the tools.
59+
tool_name: Name of the tool to execute.
3660
"""
37-
# WARNING: Do not add any other member variables or methods as this could result in a name conflict with
38-
# agent tools and thus break their execution.
3961
self._agent_ref = weakref.ref(agent)
62+
self._tool_name = tool_name
4063

4164
@property
4265
def _agent(self) -> "Agent | BidiAgent":
@@ -46,104 +69,181 @@ def _agent(self) -> "Agent | BidiAgent":
4669
raise ReferenceError("Agent has been garbage collected")
4770
return agent
4871

49-
def __getattr__(self, name: str) -> Callable[..., Any]:
50-
"""Call tool as a function.
72+
def _prepare_tool_use(self, **kwargs: Any) -> tuple[ToolUse, list[ToolResult], dict[str, Any]]:
73+
"""Prepare tool use request, results list, and invocation state.
74+
75+
Args:
76+
**kwargs: Tool parameters.
77+
78+
Returns:
79+
Tuple of (tool_use, tool_results, invocation_state).
80+
81+
Raises:
82+
AttributeError: If tool doesn't exist.
83+
"""
84+
normalized_name = self._find_normalized_tool_name(self._tool_name)
85+
tool_id = f"tooluse_{self._tool_name}_{random.randint(100000000, 999999999)}"
86+
tool_use: ToolUse = {
87+
"toolUseId": tool_id,
88+
"name": normalized_name,
89+
"input": kwargs.copy(),
90+
}
91+
return tool_use, [], kwargs
5192

52-
This method enables the method-style interface (e.g., `agent.tool.tool_name(param="value")`).
93+
def __call__(
94+
self,
95+
user_message_override: str | None = None,
96+
record_direct_tool_call: bool | None = None,
97+
**kwargs: Any,
98+
) -> ToolResult:
99+
"""Synchronous tool execution (existing behavior - backward compatible).
100+
101+
This method enables the method-style interface (e.g., ``agent.tool.tool_name(param="value")``).
53102
It matches underscore-separated names to hyphenated tool names (e.g., 'some_thing' matches 'some-thing').
54103
55104
Args:
56-
name: The name of the attribute (tool) being accessed.
105+
user_message_override: Optional custom message to record.
106+
record_direct_tool_call: Whether to record in message history.
107+
**kwargs: Tool parameters.
57108
58109
Returns:
59-
A function that when called will execute the named tool.
110+
ToolResult from execution.
60111
61112
Raises:
62-
AttributeError: If no tool with the given name exists or if multiple tools match the given name.
113+
AttributeError: If tool doesn't exist.
114+
RuntimeError: If called during interrupt.
115+
ConcurrencyException: If invocation lock cannot be acquired.
63116
"""
117+
if self._agent._interrupt_state.activated:
118+
raise RuntimeError("cannot directly call tool during interrupt")
119+
120+
if record_direct_tool_call is not None:
121+
should_record_direct_tool_call = record_direct_tool_call
122+
else:
123+
should_record_direct_tool_call = self._agent.record_direct_tool_call
124+
125+
should_lock = should_record_direct_tool_call
64126

65-
def caller(
66-
user_message_override: str | None = None,
67-
record_direct_tool_call: bool | None = None,
68-
**kwargs: Any,
69-
) -> Any:
70-
"""Call a tool directly by name.
71-
72-
Args:
73-
user_message_override: Optional custom message to record instead of default
74-
record_direct_tool_call: Whether to record direct tool calls in message history. Overrides class
75-
attribute if provided.
76-
**kwargs: Keyword arguments to pass to the tool.
77-
78-
Returns:
79-
The result returned by the tool.
80-
81-
Raises:
82-
AttributeError: If the tool doesn't exist.
83-
"""
84-
if self._agent._interrupt_state.activated:
85-
raise RuntimeError("cannot directly call tool during interrupt")
86-
87-
if record_direct_tool_call is not None:
88-
should_record_direct_tool_call = record_direct_tool_call
89-
else:
90-
should_record_direct_tool_call = self._agent.record_direct_tool_call
91-
92-
should_lock = should_record_direct_tool_call
93-
94-
from ..agent import Agent # Locally imported to avoid circular reference
95-
96-
acquired_lock = (
97-
should_lock
98-
and isinstance(self._agent, Agent)
99-
and self._agent._invocation_lock.acquire_lock(blocking=False)
127+
from ..agent import Agent # Locally imported to avoid circular reference
128+
129+
acquired_lock = (
130+
should_lock and isinstance(self._agent, Agent) and self._agent._invocation_lock.acquire_lock(blocking=False)
131+
)
132+
if should_lock and not acquired_lock:
133+
raise ConcurrencyException(
134+
"Direct tool call cannot be made while the agent is in the middle of an invocation. "
135+
"Set record_direct_tool_call=False to allow direct tool calls during agent invocation."
100136
)
101-
if should_lock and not acquired_lock:
102-
raise ConcurrencyException(
103-
"Direct tool call cannot be made while the agent is in the middle of an invocation. "
104-
"Set record_direct_tool_call=False to allow direct tool calls during agent invocation."
105-
)
106137

107-
try:
108-
normalized_name = self._find_normalized_tool_name(name)
138+
try:
139+
tool_use, tool_results, invocation_state = self._prepare_tool_use(**kwargs)
109140

110-
# Create unique tool ID and set up the tool request
111-
tool_id = f"tooluse_{name}_{random.randint(100000000, 999999999)}"
112-
tool_use: ToolUse = {
113-
"toolUseId": tool_id,
114-
"name": normalized_name,
115-
"input": kwargs.copy(),
116-
}
117-
tool_results: list[ToolResult] = []
118-
invocation_state = kwargs
141+
async def acall() -> ToolResult:
142+
async for event in ToolExecutor._stream(self._agent, tool_use, tool_results, invocation_state):
143+
if isinstance(event, ToolInterruptEvent):
144+
self._agent._interrupt_state.deactivate()
145+
raise RuntimeError("cannot raise interrupt in direct tool call")
119146

120-
async def acall() -> ToolResult:
121-
async for event in ToolExecutor._stream(self._agent, tool_use, tool_results, invocation_state):
122-
if isinstance(event, ToolInterruptEvent):
123-
self._agent._interrupt_state.deactivate()
124-
raise RuntimeError("cannot raise interrupt in direct tool call")
147+
tool_result = tool_results[0]
125148

126-
tool_result = tool_results[0]
149+
if should_record_direct_tool_call:
150+
await self._record_tool_execution(tool_use, tool_result, user_message_override)
127151

128-
if should_record_direct_tool_call:
129-
# Create a record of this tool execution in the message history
130-
await self._record_tool_execution(tool_use, tool_result, user_message_override)
152+
return tool_result
131153

132-
return tool_result
154+
tool_result = run_async(acall)
133155

134-
tool_result = run_async(acall)
156+
# TODO: https://github.com/strands-agents/sdk-python/issues/1311
157+
if isinstance(self._agent, Agent):
158+
self._agent.conversation_manager.apply_management(self._agent)
135159

136-
# TODO: https://github.com/strands-agents/sdk-python/issues/1311
137-
if isinstance(self._agent, Agent):
138-
self._agent.conversation_manager.apply_management(self._agent)
160+
return tool_result
139161

140-
return tool_result
162+
finally:
163+
if acquired_lock and isinstance(self._agent, Agent):
164+
self._agent._invocation_lock.release()
165+
166+
def stream(self, **kwargs: Any) -> Iterator[Any]:
167+
"""Synchronous streaming of tool execution events.
168+
169+
Bridges async-to-sync streaming using a background thread and queue, yielding
170+
events in real-time as they are produced by the tool.
171+
172+
This method does not acquire the invocation lock, does not record to message
173+
history, and does not apply conversation management.
141174
175+
Args:
176+
**kwargs: Tool parameters.
177+
178+
Yields:
179+
Tool execution events in real-time.
180+
181+
Raises:
182+
AttributeError: If tool doesn't exist.
183+
RuntimeError: If called during interrupt.
184+
"""
185+
# Fast-fail before spinning up a thread; stream_async also checks but this avoids unnecessary overhead
186+
if self._agent._interrupt_state.activated:
187+
raise RuntimeError("cannot directly call tool during interrupt")
188+
189+
event_queue: queue.Queue[Any] = queue.Queue()
190+
191+
async def _produce() -> None:
192+
try:
193+
async for event in self.stream_async(**kwargs):
194+
event_queue.put(event)
195+
except BaseException:
196+
# Re-raise to propagate via future.result(); the sentinel must still be placed
197+
# on the queue so the main thread unblocks before checking the future
198+
raise
142199
finally:
143-
if acquired_lock and isinstance(self._agent, Agent):
144-
self._agent._invocation_lock.release()
200+
event_queue.put(_STREAM_END)
201+
202+
context = contextvars.copy_context()
203+
with ThreadPoolExecutor(max_workers=1) as executor:
204+
future = executor.submit(context.run, asyncio.run, _produce())
205+
206+
while True:
207+
item = event_queue.get()
208+
if item is _STREAM_END:
209+
break
210+
yield item
145211

146-
return caller
212+
# Propagates any exception from the producer thread
213+
future.result()
214+
215+
async def stream_async(self, **kwargs: Any) -> AsyncIterator[Any]:
216+
"""Asynchronous streaming of tool execution events.
217+
218+
Yields events directly from tool execution without recording to message
219+
history. Designed for observability and real-time progress monitoring.
220+
221+
This method does not acquire the invocation lock, does not record to message
222+
history, and does not apply conversation management. It can be used concurrently
223+
with agent invocations.
224+
225+
Args:
226+
**kwargs: Tool parameters.
227+
228+
Yields:
229+
Tool execution events from ToolExecutor._stream().
230+
231+
Raises:
232+
AttributeError: If tool doesn't exist.
233+
RuntimeError: If called during interrupt.
234+
"""
235+
if self._agent._interrupt_state.activated:
236+
raise RuntimeError("cannot directly call tool during interrupt")
237+
238+
tool_use, tool_results, invocation_state = self._prepare_tool_use(**kwargs)
239+
240+
logger.debug("tool_name=<%s>, streaming=<True> | executing tool stream", tool_use["name"])
241+
242+
async for event in ToolExecutor._stream(self._agent, tool_use, tool_results, invocation_state):
243+
if isinstance(event, ToolInterruptEvent):
244+
self._agent._interrupt_state.deactivate()
245+
raise RuntimeError("cannot raise interrupt in direct tool call")
246+
yield event
147247

148248
def _find_normalized_tool_name(self, name: str) -> str:
149249
"""Lookup the tool represented by name, replacing characters with underscores as necessary."""
@@ -246,3 +346,39 @@ def _filter_tool_parameters_for_recording(self, tool_name: str, input_params: di
246346

247347
properties = tool_spec["inputSchema"]["json"]["properties"]
248348
return {k: v for k, v in input_params.items() if k in properties}
349+
350+
351+
class _ToolCaller:
352+
"""Call tool as a function."""
353+
354+
def __init__(self, agent: "Agent | BidiAgent") -> None:
355+
"""Initialize instance.
356+
357+
Args:
358+
agent: Agent reference that will accept tool results.
359+
"""
360+
# WARNING: Do not add any other member variables or methods as this could result in a name conflict with
361+
# agent tools and thus break their execution.
362+
self._agent_ref = weakref.ref(agent)
363+
364+
@property
365+
def _agent(self) -> "Agent | BidiAgent":
366+
"""Return the agent, raising ReferenceError if it has been garbage collected."""
367+
agent = self._agent_ref()
368+
if agent is None:
369+
raise ReferenceError("Agent has been garbage collected")
370+
return agent
371+
372+
def __getattr__(self, name: str) -> _DirectToolCall:
373+
"""Return direct tool call with streaming methods.
374+
375+
This method enables the tool calling interface by returning a callable
376+
object that provides both synchronous execution and streaming methods.
377+
378+
Args:
379+
name: Tool name.
380+
381+
Returns:
382+
Direct tool call instance.
383+
"""
384+
return _DirectToolCall(self._agent, name)

0 commit comments

Comments
 (0)