Skip to content

Commit 8b2d12a

Browse files
committed
address code review comments
1 parent 65ba07c commit 8b2d12a

File tree

4 files changed

+46
-31
lines changed

4 files changed

+46
-31
lines changed

src/agents/agent.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -537,10 +537,10 @@ async def dispatch_stream_events() -> None:
537537

538538
return run_result.final_output
539539

540-
# Set origin tracking on the FunctionTool created by @function_tool
540+
# Set origin tracking on run_agent (the FunctionTool returned by @function_tool)
541541
run_agent._tool_origin = ToolOrigin(
542542
type=ToolOriginType.AGENT_AS_TOOL,
543-
agent_as_tool_name=self.name,
543+
agent_as_tool=self,
544544
)
545545
return run_agent
546546

src/agents/mcp/util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ def to_function_tool(
179179
)
180180
function_tool._tool_origin = ToolOrigin(
181181
type=ToolOriginType.MCP,
182-
mcp_server_name=server.name,
182+
mcp_server=server,
183183
)
184184
return function_tool
185185

src/agents/tool.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
if TYPE_CHECKING:
4949
from .agent import Agent, AgentBase
5050
from .items import RunItem
51+
from .mcp.server import MCPServer
5152

5253

5354
ToolParams = ParamSpec("ToolParams")
@@ -200,19 +201,19 @@ class ToolOrigin:
200201
type: ToolOriginType
201202
"""The type of tool origin."""
202203

203-
mcp_server_name: str | None = None
204-
"""The name of the MCP server. Only set when type is MCP."""
204+
mcp_server: MCPServer | None = None
205+
"""The MCP server object. Only set when type is MCP."""
205206

206-
agent_as_tool_name: str | None = None
207-
"""The name of the agent. Only set when type is AGENT_AS_TOOL."""
207+
agent_as_tool: Agent[Any] | None = None
208+
"""The agent object. Only set when type is AGENT_AS_TOOL."""
208209

209210
def __repr__(self) -> str:
210211
"""Custom repr that only includes relevant fields."""
211212
parts = [f"type={self.type.value!r}"]
212-
if self.mcp_server_name is not None:
213-
parts.append(f"mcp_server_name={self.mcp_server_name!r}")
214-
if self.agent_as_tool_name is not None:
215-
parts.append(f"agent_as_tool_name={self.agent_as_tool_name!r}")
213+
if self.mcp_server is not None:
214+
parts.append(f"mcp_server_name={self.mcp_server.name!r}")
215+
if self.agent_as_tool is not None:
216+
parts.append(f"agent_as_tool_name={self.agent_as_tool.name!r}")
216217
return f"ToolOrigin({', '.join(parts)})"
217218

218219

tests/test_tool_origin.py

Lines changed: 34 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -44,14 +44,14 @@ def test_tool(x: int) -> str:
4444
assert len(tool_call_items) == 1
4545
assert tool_call_items[0].tool_origin is not None
4646
assert tool_call_items[0].tool_origin.type == ToolOriginType.FUNCTION
47-
assert tool_call_items[0].tool_origin.mcp_server_name is None
48-
assert tool_call_items[0].tool_origin.agent_as_tool_name is None
47+
assert tool_call_items[0].tool_origin.mcp_server is None
48+
assert tool_call_items[0].tool_origin.agent_as_tool is None
4949

5050
assert len(tool_output_items) == 1
5151
assert tool_output_items[0].tool_origin is not None
5252
assert tool_output_items[0].tool_origin.type == ToolOriginType.FUNCTION
53-
assert tool_output_items[0].tool_origin.mcp_server_name is None
54-
assert tool_output_items[0].tool_origin.agent_as_tool_name is None
53+
assert tool_output_items[0].tool_origin.mcp_server is None
54+
assert tool_output_items[0].tool_origin.agent_as_tool is None
5555

5656

5757
@pytest.mark.asyncio
@@ -78,14 +78,16 @@ async def test_mcp_tool_origin():
7878
assert len(tool_call_items) == 1
7979
assert tool_call_items[0].tool_origin is not None
8080
assert tool_call_items[0].tool_origin.type == ToolOriginType.MCP
81-
assert tool_call_items[0].tool_origin.mcp_server_name == "test_mcp_server"
82-
assert tool_call_items[0].tool_origin.agent_as_tool_name is None
81+
assert tool_call_items[0].tool_origin.mcp_server is not None
82+
assert tool_call_items[0].tool_origin.mcp_server.name == "test_mcp_server"
83+
assert tool_call_items[0].tool_origin.agent_as_tool is None
8384

8485
assert len(tool_output_items) == 1
8586
assert tool_output_items[0].tool_origin is not None
8687
assert tool_output_items[0].tool_origin.type == ToolOriginType.MCP
87-
assert tool_output_items[0].tool_origin.mcp_server_name == "test_mcp_server"
88-
assert tool_output_items[0].tool_origin.agent_as_tool_name is None
88+
assert tool_output_items[0].tool_origin.mcp_server is not None
89+
assert tool_output_items[0].tool_origin.mcp_server.name == "test_mcp_server"
90+
assert tool_output_items[0].tool_origin.agent_as_tool is None
8991

9092

9193
@pytest.mark.asyncio
@@ -123,14 +125,16 @@ async def test_agent_as_tool_origin():
123125
assert len(tool_call_items) == 1
124126
assert tool_call_items[0].tool_origin is not None
125127
assert tool_call_items[0].tool_origin.type == ToolOriginType.AGENT_AS_TOOL
126-
assert tool_call_items[0].tool_origin.mcp_server_name is None
127-
assert tool_call_items[0].tool_origin.agent_as_tool_name == "nested_agent"
128+
assert tool_call_items[0].tool_origin.mcp_server is None
129+
assert tool_call_items[0].tool_origin.agent_as_tool is not None
130+
assert tool_call_items[0].tool_origin.agent_as_tool.name == "nested_agent"
128131

129132
assert len(tool_output_items) == 1
130133
assert tool_output_items[0].tool_origin is not None
131134
assert tool_output_items[0].tool_origin.type == ToolOriginType.AGENT_AS_TOOL
132-
assert tool_output_items[0].tool_origin.mcp_server_name is None
133-
assert tool_output_items[0].tool_origin.agent_as_tool_name == "nested_agent"
135+
assert tool_output_items[0].tool_origin.mcp_server is None
136+
assert tool_output_items[0].tool_origin.agent_as_tool is not None
137+
assert tool_output_items[0].tool_origin.agent_as_tool.name == "nested_agent"
134138

135139

136140
@pytest.mark.asyncio
@@ -192,10 +196,12 @@ def func_tool(x: int) -> str:
192196
assert function_item.tool_origin.type == ToolOriginType.FUNCTION
193197
assert mcp_item.tool_origin is not None
194198
assert mcp_item.tool_origin.type == ToolOriginType.MCP
195-
assert mcp_item.tool_origin.mcp_server_name == "mcp_server"
199+
assert mcp_item.tool_origin.mcp_server is not None
200+
assert mcp_item.tool_origin.mcp_server.name == "mcp_server"
196201
assert agent_item.tool_origin is not None
197202
assert agent_item.tool_origin.type == ToolOriginType.AGENT_AS_TOOL
198-
assert agent_item.tool_origin.agent_as_tool_name == "nested"
203+
assert agent_item.tool_origin.agent_as_tool is not None
204+
assert agent_item.tool_origin.agent_as_tool.name == "nested"
199205

200206

201207
@pytest.mark.asyncio
@@ -229,12 +235,14 @@ async def test_tool_origin_streaming():
229235
assert len(tool_call_items) == 1
230236
assert tool_call_items[0].tool_origin is not None
231237
assert tool_call_items[0].tool_origin.type == ToolOriginType.MCP
232-
assert tool_call_items[0].tool_origin.mcp_server_name == "streaming_server"
238+
assert tool_call_items[0].tool_origin.mcp_server is not None
239+
assert tool_call_items[0].tool_origin.mcp_server.name == "streaming_server"
233240

234241
assert len(tool_output_items) == 1
235242
assert tool_output_items[0].tool_origin is not None
236243
assert tool_output_items[0].tool_origin.type == ToolOriginType.MCP
237-
assert tool_output_items[0].tool_origin.mcp_server_name == "streaming_server"
244+
assert tool_output_items[0].tool_origin.mcp_server is not None
245+
assert tool_output_items[0].tool_origin.mcp_server.name == "streaming_server"
238246

239247

240248
@pytest.mark.asyncio
@@ -246,12 +254,18 @@ async def test_tool_origin_repr():
246254
assert "agent_as_tool_name" not in repr(function_origin)
247255

248256
# MCP origin
249-
mcp_origin = ToolOrigin(type=ToolOriginType.MCP, mcp_server_name="test_server")
250-
assert "mcp_server_name='test_server'" in repr(mcp_origin)
251-
assert "agent_as_tool_name" not in repr(mcp_origin)
257+
if sys.version_info >= (3, 10):
258+
from .mcp.helpers import FakeMCPServer
259+
260+
test_server = FakeMCPServer(server_name="test_server")
261+
mcp_origin = ToolOrigin(type=ToolOriginType.MCP, mcp_server=test_server)
262+
assert "mcp_server_name='test_server'" in repr(mcp_origin)
263+
assert "agent_as_tool_name" not in repr(mcp_origin)
252264

253265
# AGENT_AS_TOOL origin
254-
agent_origin = ToolOrigin(type=ToolOriginType.AGENT_AS_TOOL, agent_as_tool_name="test_agent")
266+
model = FakeModel()
267+
test_agent = Agent(name="test_agent", model=model, instructions="Test agent")
268+
agent_origin = ToolOrigin(type=ToolOriginType.AGENT_AS_TOOL, agent_as_tool=test_agent)
255269
assert "agent_as_tool_name='test_agent'" in repr(agent_origin)
256270
assert "mcp_server_name" not in repr(agent_origin)
257271

0 commit comments

Comments
 (0)