Skip to content

Commit 388b7af

Browse files
authored
feat: Update OpenAI graph runner to return AgentGraphRunnerResult with GraphMetrics (#155)
1 parent 20a5020 commit 388b7af

3 files changed

Lines changed: 191 additions & 133 deletions

File tree

packages/ai-providers/server-ai-openai/src/ldai_openai/openai_agent_graph_runner.py

Lines changed: 63 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
import re
22
import time
3-
from typing import Any, Dict, List, Optional
3+
from typing import Any, Dict, List
44

55
from ldai import log
66
from ldai.agent_graph import AgentGraphDefinition, AgentGraphNode
7-
from ldai.providers import AgentGraphResult, AgentGraphRunner, ToolRegistry
8-
from ldai.providers.types import LDAIMetrics
9-
from ldai.tracker import TokenUsage
7+
from ldai.providers import AgentGraphRunner, ToolRegistry
8+
from ldai.providers.types import AgentGraphRunnerResult, GraphMetrics, LDAIMetrics
109

1110
from ldai_openai.openai_helper import (
1211
extract_usage_from_request_entry,
@@ -39,9 +38,10 @@ class OpenAIAgentGraphRunner(AgentGraphRunner):
3938
4039
AgentGraphRunner implementation for the OpenAI Agents SDK.
4140
42-
Runs the agent graph with the OpenAI Agents SDK and automatically records
43-
graph- and node-level AI metric data to the LaunchDarkly trackers on the
44-
graph definition and each node.
41+
Runs the agent graph with the OpenAI Agents SDK and collects graph- and
42+
node-level metrics. Tracking events are emitted by the managed layer
43+
(:class:`~ldai.ManagedAgentGraph`) from the returned
44+
:class:`~ldai.providers.types.AgentGraphRunnerResult`.
4545
4646
Requires ``openai-agents`` to be installed.
4747
"""
@@ -61,20 +61,20 @@ def __init__(
6161
self._tools = tools
6262
self._agent_name_map: Dict[str, str] = {}
6363
self._tool_name_map: Dict[str, str] = {}
64-
self._node_trackers: Dict[str, Any] = {}
64+
self._node_metrics: Dict[str, LDAIMetrics] = {}
6565

66-
async def run(self, input: Any) -> AgentGraphResult:
66+
async def run(self, input: Any) -> AgentGraphRunnerResult:
6767
"""
6868
Run the agent graph with the given input.
6969
7070
Builds the agent tree via reverse_traverse, then invokes the root
71-
agent with Runner.run(). Tracks path, latency, and invocation
72-
success/failure.
71+
agent with Runner.run(). Collects path, latency, and per-node metrics.
72+
Graph-level tracking events are emitted by the managed layer.
7373
7474
:param input: The string prompt to send to the agent graph
75-
:return: AgentGraphResult with the final output and metrics
75+
:return: AgentGraphRunnerResult with the final content and GraphMetrics
7676
"""
77-
tracker = self._graph.create_tracker()
77+
self._node_metrics = {}
7878
path: List[str] = []
7979
root_node = self._graph.root()
8080
root_key = root_node.get_key() if root_node else ''
@@ -86,24 +86,26 @@ async def run(self, input: Any) -> AgentGraphResult:
8686
state = _RunState(last_handoff_ns=start_ns, last_node_key=root_key)
8787
try:
8888
from agents import Runner
89-
root_agent = self._build_agents(path, state, tracker)
89+
root_agent = self._build_agents(path, state)
90+
if root_key:
91+
self._node_metrics[root_key] = LDAIMetrics(success=False)
9092
result = await Runner.run(root_agent, input_str)
9193
self._flush_final_segment(state, result)
92-
self._track_tool_calls(result)
94+
self._collect_tool_calls(result)
9395

94-
duration = (time.perf_counter_ns() - start_ns) // 1_000_000
96+
duration_ms = (time.perf_counter_ns() - start_ns) // 1_000_000
9597
token_usage = get_ai_usage_from_response(result)
9698

97-
tracker.track_path(path)
98-
tracker.track_duration(duration)
99-
tracker.track_invocation_success()
100-
if token_usage is not None:
101-
tracker.track_total_tokens(token_usage)
102-
103-
return AgentGraphResult(
104-
output=str(result.final_output),
99+
return AgentGraphRunnerResult(
100+
content=str(result.final_output),
105101
raw=result,
106-
metrics=LDAIMetrics(success=True, usage=token_usage),
102+
metrics=GraphMetrics(
103+
success=True,
104+
path=path,
105+
duration_ms=duration_ms,
106+
usage=token_usage,
107+
node_metrics=self._node_metrics,
108+
),
107109
)
108110
except Exception as exc:
109111
if isinstance(exc, ImportError):
@@ -113,17 +115,20 @@ async def run(self, input: Any) -> AgentGraphResult:
113115
)
114116
else:
115117
log.warning(f'OpenAIAgentGraphRunner run failed: {exc}')
116-
duration = (time.perf_counter_ns() - start_ns) // 1_000_000
117-
tracker.track_duration(duration)
118-
tracker.track_invocation_failure()
119-
return AgentGraphResult(
120-
output='',
118+
duration_ms = (time.perf_counter_ns() - start_ns) // 1_000_000
119+
return AgentGraphRunnerResult(
120+
content='',
121121
raw=None,
122-
metrics=LDAIMetrics(success=False),
122+
metrics=GraphMetrics(
123+
success=False,
124+
path=path,
125+
duration_ms=duration_ms,
126+
node_metrics=self._node_metrics,
127+
),
123128
)
124129

125130
def _build_agents(
126-
self, path: List[str], state: _RunState, tracker: Any
131+
self, path: List[str], state: _RunState
127132
) -> Any:
128133
"""
129134
Build the agent tree from the graph definition via reverse_traverse.
@@ -133,7 +138,6 @@ def _build_agents(
133138
134139
:param path: Mutable list to accumulate the execution path
135140
:param state: Shared run state for tracking handoff timing and last node
136-
:param tracker: Graph-level tracker shared across the entire run
137141
:return: The root Agent instance
138142
"""
139143
try:
@@ -151,12 +155,9 @@ def _build_agents(
151155

152156
name_map: Dict[str, str] = {}
153157
tool_name_map: Dict[str, str] = {}
154-
node_trackers: Dict[str, Any] = {}
155158

156159
def build_node(node: AgentGraphNode, ctx: dict) -> Any:
157160
node_config = node.get_config()
158-
config_tracker = node_config.create_tracker()
159-
node_trackers[node_config.key] = config_tracker
160161
model = node_config.model
161162

162163
if not model:
@@ -177,8 +178,6 @@ def build_node(node: AgentGraphNode, ctx: dict) -> Any:
177178
node_config.key,
178179
target_key,
179180
path,
180-
tracker,
181-
config_tracker,
182181
state,
183182
),
184183
)
@@ -212,20 +211,17 @@ def build_node(node: AgentGraphNode, ctx: dict) -> Any:
212211
root = self._graph.reverse_traverse(fn=build_node)
213212
self._agent_name_map = name_map
214213
self._tool_name_map = tool_name_map
215-
self._node_trackers = node_trackers
216214
return root
217215

218216
def _make_on_handoff(
219217
self,
220218
src: str,
221219
tgt: str,
222220
path: List[str],
223-
tracker: Any,
224-
config_tracker: Any,
225221
state: _RunState,
226222
):
227223
def on_handoff(run_ctx: Any) -> None:
228-
self._handle_handoff(run_ctx, src, tgt, path, tracker, config_tracker, state)
224+
self._handle_handoff(run_ctx, src, tgt, path, state)
229225
return on_handoff
230226

231227
def _handle_handoff(
@@ -234,64 +230,57 @@ def _handle_handoff(
234230
src: str,
235231
tgt: str,
236232
path: List[str],
237-
tracker: Any,
238-
config_tracker: Any,
239233
state: _RunState,
240234
) -> None:
241235
path.append(tgt)
242-
state.last_node_key = tgt
243-
tracker.track_handoff_success(src, tgt)
244236

245237
now_ns = time.perf_counter_ns()
246238
duration_ms = (now_ns - state.last_handoff_ns) // 1_000_000
247239
state.last_handoff_ns = now_ns
248240

249-
usage: Optional[TokenUsage] = None
250-
try:
251-
usage = extract_usage_from_request_entry(
252-
run_ctx.usage.request_usage_entries[-1]
253-
)
254-
except Exception:
255-
pass
241+
src_metrics = self._node_metrics.get(src)
242+
if src_metrics is not None:
243+
src_metrics.success = True
244+
src_metrics.duration_ms = int(duration_ms)
245+
try:
246+
src_metrics.usage = extract_usage_from_request_entry(
247+
run_ctx.usage.request_usage_entries[-1]
248+
)
249+
except Exception:
250+
pass
256251

257-
if config_tracker is not None:
258-
if usage is not None:
259-
config_tracker.track_tokens(usage)
260-
if duration_ms is not None:
261-
config_tracker.track_duration(int(duration_ms))
262-
config_tracker.track_success()
252+
self._node_metrics[tgt] = LDAIMetrics(success=False)
253+
state.last_node_key = tgt
263254

264255
def _flush_final_segment(self, state: _RunState, result: Any) -> None:
265256
"""Record duration/tokens for the last active agent (no handoff after it)."""
266257
if not state.last_node_key:
267258
return
268-
config_tracker = self._node_trackers.get(state.last_node_key)
269-
if config_tracker is None:
259+
metrics = self._node_metrics.get(state.last_node_key)
260+
if metrics is None:
270261
return
271262

263+
metrics.success = True
272264
now_ns = time.perf_counter_ns()
273-
duration_ms = (now_ns - state.last_handoff_ns) // 1_000_000
265+
metrics.duration_ms = int((now_ns - state.last_handoff_ns) // 1_000_000)
274266

275-
usage: Optional[TokenUsage] = None
276267
try:
277-
usage = extract_usage_from_request_entry(
268+
metrics.usage = extract_usage_from_request_entry(
278269
result.context_wrapper.usage.request_usage_entries[-1]
279270
)
280271
except Exception:
281272
pass
282273

283-
if usage is not None:
284-
config_tracker.track_tokens(usage)
285-
config_tracker.track_duration(int(duration_ms))
286-
config_tracker.track_success()
287-
288-
def _track_tool_calls(self, result: Any) -> None:
289-
"""Track all tool calls from the run result, attributed to the node that called them."""
274+
def _collect_tool_calls(self, result: Any) -> None:
275+
"""Collect all tool calls from the run result, attributed to the node that called them."""
290276
for agent_name, tool_fn_name in get_tool_calls_from_run_items(result.new_items):
291277
agent_key = self._agent_name_map.get(agent_name, agent_name)
292278
tool_name = self._tool_name_map.get(tool_fn_name)
293279
if tool_name is None:
294280
continue
295-
config_tracker = self._node_trackers.get(agent_key)
296-
if config_tracker is not None:
297-
config_tracker.track_tool_call(tool_name)
281+
metrics = self._node_metrics.get(agent_key)
282+
if metrics is not None:
283+
if metrics.tool_calls is None:
284+
metrics.tool_calls = [tool_name]
285+
else:
286+
metrics.tool_calls.append(tool_name)

0 commit comments

Comments
 (0)