forked from openai/openai-agents-python
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtool_use_tracker.py
More file actions
152 lines (124 loc) · 5.19 KB
/
tool_use_tracker.py
File metadata and controls
152 lines (124 loc) · 5.19 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
"""
Tool-use tracking utilities. Hosts AgentToolUseTracker and helpers to serialize/deserialize
its state plus lightweight tool-call type utilities. Internal use only.
"""
from __future__ import annotations
from typing import Any, get_args, get_origin
from .._tool_identity import get_function_tool_trace_name
from ..agent import Agent
from ..items import (
HandoffCallItem,
ToolCallItem,
ToolCallItemTypes,
ToolCallOutputItem,
ToolSearchCallItem,
ToolSearchOutputItem,
)
from ..run_state import _build_agent_map
from .run_steps import ProcessedResponse, ToolRunFunction
# Names exported by a star-import of this module.
__all__ = [
    "AgentToolUseTracker",
    "serialize_tool_use_tracker",
    "hydrate_tool_use_tracker",
    "get_tool_call_types",
    "TOOL_CALL_TYPES",
]
# Item types whose paired tool name AgentToolUseTracker.record_processed_response
# actually records as "used" (the usage that drives model_settings resets).
_TOOL_USE_RESET_TRACKING_ITEM_TYPES = (
    HandoffCallItem,
    ToolCallItem,
    ToolCallOutputItem,
)
# Item types that each consume one name from ProcessedResponse.tools_used while
# record_processed_response walks new_items. The tool-search items consume a name
# (keeping the pairing aligned) but are not in the tuple above, so they are
# never recorded.
_PROCESSED_RESPONSE_TOOL_ITEM_TYPES = (
    HandoffCallItem,
    ToolCallItem,
    ToolCallOutputItem,
    ToolSearchCallItem,
    ToolSearchOutputItem,
)
class AgentToolUseTracker:
    """Remember which tools each agent has invoked, to drive model_settings resets.

    State is kept in two parallel views:

    * ``agent_map`` -- agent name -> set of tool names. Only serialization and
      hydration read it.
    * ``agent_to_tools`` -- ``(agent instance, tool names)`` pairs matched by
      identity (``is``). ``has_used_tools`` reads this view, so a tracker built by
      ``from_serializable`` reports no usage until usage is added for an agent
      instance.
    """

    def __init__(self) -> None:
        # Name-keyed view; consumed by (de)serialization only.
        self.agent_map: dict[str, set[str]] = {}
        # Identity-keyed view; consumed by runtime checks such as has_used_tools.
        self.agent_to_tools: list[tuple[Agent[Any], list[str]]] = []

    def _entry_for(self, agent: Agent[Any]) -> tuple[Agent[Any], list[str]] | None:
        """Return the ``agent_to_tools`` pair whose agent *is* ``agent``, else None."""
        for entry in self.agent_to_tools:
            if entry[0] is agent:
                return entry
        return None

    def record_used_tools(self, agent: Agent[Any], tools: list[ToolRunFunction]) -> None:
        """Record every function tool in ``tools`` as used by ``agent``.

        Each tool is named by its trace name when that is non-empty, otherwise by
        the function tool's ``name``.
        """
        used: list[str] = []
        for run in tools:
            function_tool = run.function_tool
            used.append(get_function_tool_trace_name(function_tool) or function_tool.name)
        self.add_tool_use(agent, used)

    def record_processed_response(
        self, agent: Agent[Any], processed_response: ProcessedResponse
    ) -> None:
        """Record resettable tool usage found in a processed model response.

        Tool-related entries of ``processed_response.new_items`` are paired, in
        order, with names drawn from ``processed_response.tools_used``. Only
        handoff calls, tool calls and tool outputs are recorded; tool-search items
        still consume a name. Pairing stops as soon as ``tools_used`` runs out.
        """
        # NOTE(review): assumes tools_used lists one name per tool-related item,
        # in the same order as new_items -- confirm against ProcessedResponse.
        pending_names = iter(processed_response.tools_used)
        recorded: list[str] = []
        for entry in processed_response.new_items:
            if isinstance(entry, _PROCESSED_RESPONSE_TOOL_ITEM_TYPES):
                paired_name = next(pending_names, None)
                if paired_name is None:
                    break
                if isinstance(entry, _TOOL_USE_RESET_TRACKING_ITEM_TYPES):
                    recorded.append(paired_name)
        self.add_tool_use(agent, recorded)

    def add_tool_use(self, agent: Agent[Any], tool_names: list[str]) -> None:
        """Record ``tool_names`` as used by ``agent``; a no-op for an empty list.

        Public entry point kept for callers that append tool usage directly. The
        name-keyed set is updated, and the per-instance list is extended (repeated
        names are kept in that list).
        """
        if not tool_names:
            return
        # Agents lacking a ``name`` attribute are keyed by their class name.
        key = getattr(agent, "name", agent.__class__.__name__)
        self.agent_map.setdefault(key, set()).update(tool_names)
        entry = self._entry_for(agent)
        if entry is None:
            # Store a copy so later extends never mutate the caller's list.
            self.agent_to_tools.append((agent, list(tool_names)))
        else:
            entry[1].extend(tool_names)

    def has_used_tools(self, agent: Agent[Any]) -> bool:
        """Return True if this exact agent instance has at least one recorded tool."""
        entry = self._entry_for(agent)
        return entry is not None and bool(entry[1])

    def as_serializable(self) -> dict[str, list[str]]:
        """Return a mapping of agent name -> sorted tool names.

        ``agent_map`` is used when it is non-empty; otherwise the mapping is rebuilt
        from ``agent_to_tools``, merging instances that share a name.
        """
        by_name: dict[str, set[str]]
        if self.agent_map:
            by_name = self.agent_map
        else:
            by_name = {}
            for owner, names in self.agent_to_tools:
                owner_key = getattr(owner, "name", owner.__class__.__name__)
                by_name.setdefault(owner_key, set()).update(names)
        return {name: sorted(tools) for name, tools in by_name.items()}

    @classmethod
    def from_serializable(cls, data: dict[str, list[str]]) -> AgentToolUseTracker:
        """Build a tracker whose ``agent_map`` mirrors ``data``.

        ``agent_to_tools`` is left empty because ``data`` carries only agent names,
        not agent instances.
        """
        restored = cls()
        restored.agent_map = {name: set(names) for name, names in data.items()}
        return restored
def serialize_tool_use_tracker(tool_use_tracker: AgentToolUseTracker) -> dict[str, list[str]]:
    """Convert the AgentToolUseTracker into a serializable snapshot.

    Returns a mapping of agent name -> tool names, built as follows:

    * Entries from ``agent_to_tools`` come first. Tool names are deduplicated and
      kept in first-use order, and distinct agent instances that share a name are
      merged instead of overwriting one another.
    * Names known only through ``agent_map`` (e.g. a tracker created with
      ``AgentToolUseTracker.from_serializable`` that has no instance entries) are
      added with their tool names sorted, so that data is not dropped.
    """
    snapshot: dict[str, list[str]] = {}
    for agent, tool_names in tool_use_tracker.agent_to_tools:
        # Same name fallback as AgentToolUseTracker.add_tool_use.
        agent_name = getattr(agent, "name", agent.__class__.__name__)
        merged = snapshot.setdefault(agent_name, [])
        for tool_name in tool_names:
            # add_tool_use extends per-instance lists with repeats every turn;
            # persisting those repeats would grow the saved state without bound.
            if tool_name not in merged:
                merged.append(tool_name)
    for agent_name, tool_names in tool_use_tracker.agent_map.items():
        if agent_name not in snapshot:
            snapshot[agent_name] = sorted(tool_names)
    return snapshot
def hydrate_tool_use_tracker(
    tool_use_tracker: AgentToolUseTracker,
    run_state: Any,
    starting_agent: Agent[Any],
) -> None:
    """Replay the tool-use snapshot held by ``run_state`` into ``tool_use_tracker``.

    The snapshot comes from ``run_state.get_tool_use_tracker_snapshot()``; nothing
    happens when it is empty or None. Snapshot agent names are resolved to agent
    instances through the first element returned by
    ``_build_agent_map(starting_agent)``. Names that do not resolve are skipped.
    """
    # presumably run_state is a RunState; only the snapshot accessor is used here.
    saved = run_state.get_tool_use_tracker_snapshot()
    if saved:
        agents_by_name, _ = _build_agent_map(starting_agent)
        for name, used in saved.items():
            match = agents_by_name.get(name)
            if match is not None:
                tool_use_tracker.add_tool_use(match, list(used))
def get_tool_call_types() -> tuple[type, ...]:
    """Return the concrete classes named by the ``ToolCallItemTypes`` union.

    Each member of the union is reduced to its runtime origin when it is a
    parameterized generic; members that do not resolve to a class are dropped.
    """
    resolved = (get_origin(hint) or hint for hint in get_args(ToolCallItemTypes))
    return tuple(candidate for candidate in resolved if isinstance(candidate, type))
# Concrete tool-call item classes, computed once at import time from
# ToolCallItemTypes via get_tool_call_types().
TOOL_CALL_TYPES: tuple[type, ...] = get_tool_call_types()