Skip to content
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 2 additions & 111 deletions src/strands/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
from ..session.session_manager import SessionManager
from ..telemetry.metrics import EventLoopMetrics
from ..telemetry.tracer import get_tracer, serialize
from ..tools.caller import ToolCaller
from ..tools.executors import ConcurrentToolExecutor
from ..tools.executors._executor import ToolExecutor
from ..tools.registry import ToolRegistry
Expand Down Expand Up @@ -102,116 +103,6 @@ class Agent:
6. Produces a final response
"""

class ToolCaller:
Comment thread
mehtarac marked this conversation as resolved.
"""Call tool as a function."""

def __init__(self, agent: "Agent") -> None:
"""Initialize instance.

Args:
agent: Agent reference that will accept tool results.
"""
# WARNING: Do not add any other member variables or methods as this could result in a name conflict with
# agent tools and thus break their execution.
self._agent = agent

def __getattr__(self, name: str) -> Callable[..., Any]:
"""Call tool as a function.

This method enables the method-style interface (e.g., `agent.tool.tool_name(param="value")`).
It matches underscore-separated names to hyphenated tool names (e.g., 'some_thing' matches 'some-thing').

Args:
name: The name of the attribute (tool) being accessed.

Returns:
A function that when called will execute the named tool.

Raises:
AttributeError: If no tool with the given name exists or if multiple tools match the given name.
"""

def caller(
user_message_override: Optional[str] = None,
record_direct_tool_call: Optional[bool] = None,
**kwargs: Any,
) -> Any:
"""Call a tool directly by name.

Args:
user_message_override: Optional custom message to record instead of default
record_direct_tool_call: Whether to record direct tool calls in message history. Overrides class
attribute if provided.
**kwargs: Keyword arguments to pass to the tool.

Returns:
The result returned by the tool.

Raises:
AttributeError: If the tool doesn't exist.
"""
if self._agent._interrupt_state.activated:
raise RuntimeError("cannot directly call tool during interrupt")

normalized_name = self._find_normalized_tool_name(name)

# Create unique tool ID and set up the tool request
tool_id = f"tooluse_{name}_{random.randint(100000000, 999999999)}"
tool_use: ToolUse = {
"toolUseId": tool_id,
"name": normalized_name,
"input": kwargs.copy(),
}
tool_results: list[ToolResult] = []
invocation_state = kwargs

async def acall() -> ToolResult:
async for event in ToolExecutor._stream(self._agent, tool_use, tool_results, invocation_state):
if isinstance(event, ToolInterruptEvent):
self._agent._interrupt_state.deactivate()
raise RuntimeError("cannot raise interrupt in direct tool call")

return tool_results[0]

tool_result = run_async(acall)

if record_direct_tool_call is not None:
should_record_direct_tool_call = record_direct_tool_call
else:
should_record_direct_tool_call = self._agent.record_direct_tool_call

if should_record_direct_tool_call:
# Create a record of this tool execution in the message history
self._agent._record_tool_execution(tool_use, tool_result, user_message_override)

# Apply window management
self._agent.conversation_manager.apply_management(self._agent)

return tool_result

return caller

def _find_normalized_tool_name(self, name: str) -> str:
"""Lookup the tool represented by name, replacing characters with underscores as necessary."""
tool_registry = self._agent.tool_registry.registry

if tool_registry.get(name, None):
return name

# If the desired name contains underscores, it might be a placeholder for characters that can't be
# represented as python identifiers but are valid as tool names, such as dashes. In that case, find
# all tools that can be represented with the normalized name
if "_" in name:
filtered_tools = [
tool_name for (tool_name, tool) in tool_registry.items() if tool_name.replace("-", "_") == name
]

# The registry itself defends against similar names, so we can just take the first match
if filtered_tools:
return filtered_tools[0]

raise AttributeError(f"Tool '{name}' not found")

def __init__(
self,
model: Union[Model, str, None] = None,
Expand Down Expand Up @@ -349,7 +240,7 @@ def __init__(
else:
self.state = AgentState()

self.tool_caller = Agent.ToolCaller(self)
self.tool_caller = ToolCaller(self)
Comment thread
mehtarac marked this conversation as resolved.
Outdated

self.hooks = HookRegistry()

Expand Down
Loading
Loading