Skip to content

Commit 131f30e

Browse files
committed
use run state for accurate tracking
1 parent fbb03e3 commit 131f30e

1 file changed

Lines changed: 27 additions & 19 deletions

File tree

packages/ai-providers/server-ai-openai/src/ldai_openai/openai_agent_graph_runner.py

Lines changed: 27 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,14 @@ def _build_native_tool_map() -> dict:
3636
_NATIVE_OPENAI_TOOLS = _build_native_tool_map()
3737

3838

39+
class _RunState:
40+
"""Mutable state shared across handoff and tool callbacks during a single run."""
41+
42+
def __init__(self, last_handoff_ns: int, last_node_key: str) -> None:
43+
self.last_handoff_ns = last_handoff_ns
44+
self.last_node_key = last_node_key
45+
46+
3947
class OpenAIAgentGraphRunner(AgentGraphRunner):
4048
"""
4149
AgentGraphRunner implementation for the OpenAI Agents SDK.
@@ -71,17 +79,17 @@ async def run(self, input: Any) -> AgentGraphResult:
7179
tracker = self._graph.get_tracker()
7280
path: List[str] = []
7381
root_node = self._graph.root()
74-
if root_node:
75-
path.append(root_node.get_key())
82+
root_key = root_node.get_key() if root_node else ''
83+
if root_key:
84+
path.append(root_key)
7685

7786
start_ns = time.perf_counter_ns()
78-
# Mutable cell so handoff callbacks can update time-between-handoffs without globals.
79-
last_handoff_ns: List[int] = [start_ns]
87+
state = _RunState(last_handoff_ns=start_ns, last_node_key=root_key)
8088
try:
8189
from agents import Runner
82-
root_agent = self._build_agents(path, last_handoff_ns)
90+
root_agent = self._build_agents(path, state)
8391
result = await Runner.run(root_agent, str(input))
84-
self._flush_final_segment(path, last_handoff_ns, tracker, result)
92+
self._flush_final_segment(state, tracker, result)
8593

8694
duration = (time.perf_counter_ns() - start_ns) // 1_000_000
8795

@@ -123,24 +131,22 @@ async def run(self, input: Any) -> AgentGraphResult:
123131

124132
def _flush_final_segment(
125133
self,
126-
path: List[str],
127-
last_handoff_ns: List[int],
134+
state: _RunState,
128135
tracker: Any,
129136
result: Any,
130137
) -> None:
131138
"""Record duration/tokens for the last active agent (no handoff after it)."""
132-
if not path:
139+
if not state.last_node_key:
133140
return
134-
last_key = path[-1]
135-
node = self._graph.get_node(last_key)
141+
node = self._graph.get_node(state.last_node_key)
136142
if node is None:
137143
return
138144
config_tracker = node.get_config().tracker
139145
if config_tracker is None:
140146
return
141147

142148
now_ns = time.perf_counter_ns()
143-
duration_ms = (now_ns - last_handoff_ns[0]) // 1_000_000
149+
duration_ms = (now_ns - state.last_handoff_ns) // 1_000_000
144150

145151
usage: Optional[TokenUsage] = None
146152
try:
@@ -167,16 +173,17 @@ def _handle_handoff(
167173
path: List[str],
168174
tracker: Any,
169175
config_tracker: Any,
170-
last_handoff_ns: List[int],
176+
state: _RunState,
171177
) -> None:
172178
path.append(tgt)
179+
state.last_node_key = tgt
173180
if tracker:
174181
tracker.track_handoff_success(src, tgt)
175182

176183
usage: Optional[TokenUsage] = None
177184
now_ns = time.perf_counter_ns()
178-
duration_ms = (now_ns - last_handoff_ns[0]) // 1_000_000
179-
last_handoff_ns[0] = now_ns
185+
duration_ms = (now_ns - state.last_handoff_ns) // 1_000_000
186+
state.last_handoff_ns = now_ns
180187
try:
181188
usage_entry = run_ctx.usage.request_usage_entries[-1]
182189
usage = TokenUsage(
@@ -202,22 +209,23 @@ def _make_on_handoff(
202209
path: List[str],
203210
tracker: Any,
204211
config_tracker: Any,
205-
last_handoff_ns: List[int],
212+
state: _RunState,
206213
):
207214
def on_handoff(run_ctx: Any) -> None:
208215
self._handle_handoff(
209-
run_ctx, src, tgt, path, tracker, config_tracker, last_handoff_ns
216+
run_ctx, src, tgt, path, tracker, config_tracker, state
210217
)
211218
return on_handoff
212219

213-
def _build_agents(self, path: List[str], last_handoff_ns: List[int]) -> Any:
220+
def _build_agents(self, path: List[str], state: _RunState) -> Any:
214221
"""
215222
Build the agent tree from the graph definition via reverse_traverse.
216223
217224
Agents are constructed from terminal nodes upward so that handoff
218225
targets exist before the agents that hand off to them.
219226
220227
:param path: Mutable list to accumulate the execution path
228+
:param state: Shared run state for tracking handoff timing and last node
221229
:return: The root Agent instance
222230
"""
223231
try:
@@ -262,7 +270,7 @@ def build_node(node: AgentGraphNode, ctx: dict) -> Any:
262270
path,
263271
tracker,
264272
config_tracker,
265-
last_handoff_ns,
273+
state,
266274
),
267275
)
268276
)

0 commit comments

Comments
 (0)