Skip to content

Commit 3d43c96

Browse files
committed
simplify the node metrics collection
1 parent f6df746 commit 3d43c96

2 files changed

Lines changed: 25 additions & 62 deletions

File tree

packages/ai-providers/server-ai-openai/src/ldai_openai/openai_agent_graph_runner.py

Lines changed: 23 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
import re
22
import time
3-
from typing import Any, Dict, List, Optional
3+
from typing import Any, Dict, List
44

55
from ldai import log
66
from ldai.agent_graph import AgentGraphDefinition, AgentGraphNode
77
from ldai.providers import AgentGraphRunner, ToolRegistry
88
from ldai.providers.types import AgentGraphRunnerResult, GraphMetrics, LDAIMetrics
9-
from ldai.tracker import TokenUsage
109

1110
from ldai_openai.openai_helper import (
1211
extract_usage_from_request_entry,
@@ -22,34 +21,6 @@ def _sanitize_agent_name(key: str) -> str:
2221
return re.sub(r'[^a-zA-Z0-9_]', '_', key)
2322

2423

25-
class _NodeMetricsAccumulator:
26-
"""Mutable per-node metrics collected during a run (replaces LDAIConfigTracker)."""
27-
28-
def __init__(self) -> None:
29-
self.usage: Optional[TokenUsage] = None
30-
self.duration_ms: Optional[int] = None
31-
self.tool_calls: List[str] = []
32-
self.success: bool = True
33-
34-
def set_usage(self, usage: Optional[TokenUsage]) -> None:
35-
if usage is not None:
36-
self.usage = usage
37-
38-
def set_duration_ms(self, duration_ms: int) -> None:
39-
self.duration_ms = duration_ms
40-
41-
def add_tool_call(self, tool_name: str) -> None:
42-
self.tool_calls.append(tool_name)
43-
44-
def to_ldai_metrics(self) -> LDAIMetrics:
45-
return LDAIMetrics(
46-
success=self.success,
47-
usage=self.usage,
48-
duration_ms=self.duration_ms,
49-
tool_calls=self.tool_calls if self.tool_calls else None,
50-
)
51-
52-
5324
class _RunState:
5425
"""Mutable state shared across handoff and tool callbacks during a single run."""
5526

@@ -90,7 +61,7 @@ def __init__(
9061
self._tools = tools
9162
self._agent_name_map: Dict[str, str] = {}
9263
self._tool_name_map: Dict[str, str] = {}
93-
self._node_accumulators: Dict[str, _NodeMetricsAccumulator] = {}
64+
self._node_metrics: Dict[str, LDAIMetrics] = {}
9465

9566
async def run(self, input: Any) -> AgentGraphRunnerResult:
9667
"""
@@ -122,11 +93,6 @@ async def run(self, input: Any) -> AgentGraphRunnerResult:
12293
duration_ms = (time.perf_counter_ns() - start_ns) // 1_000_000
12394
token_usage = get_ai_usage_from_response(result)
12495

125-
node_metrics = {
126-
key: acc.to_ldai_metrics()
127-
for key, acc in self._node_accumulators.items()
128-
}
129-
13096
return AgentGraphRunnerResult(
13197
content=str(result.final_output),
13298
raw=result,
@@ -135,7 +101,7 @@ async def run(self, input: Any) -> AgentGraphRunnerResult:
135101
path=path,
136102
duration_ms=duration_ms,
137103
usage=token_usage,
138-
node_metrics=node_metrics,
104+
node_metrics=self._node_metrics,
139105
),
140106
)
141107
except Exception as exc:
@@ -185,12 +151,12 @@ def _build_agents(
185151

186152
name_map: Dict[str, str] = {}
187153
tool_name_map: Dict[str, str] = {}
188-
node_accumulators: Dict[str, _NodeMetricsAccumulator] = {}
154+
node_metrics: Dict[str, LDAIMetrics] = {}
189155

190156
def build_node(node: AgentGraphNode, ctx: dict) -> Any:
191157
node_config = node.get_config()
192-
acc = _NodeMetricsAccumulator()
193-
node_accumulators[node_config.key] = acc
158+
metrics = LDAIMetrics(success=True)
159+
node_metrics[node_config.key] = metrics
194160
model = node_config.model
195161

196162
if not model:
@@ -211,7 +177,7 @@ def build_node(node: AgentGraphNode, ctx: dict) -> Any:
211177
node_config.key,
212178
target_key,
213179
path,
214-
acc,
180+
metrics,
215181
state,
216182
),
217183
)
@@ -245,19 +211,19 @@ def build_node(node: AgentGraphNode, ctx: dict) -> Any:
245211
root = self._graph.reverse_traverse(fn=build_node)
246212
self._agent_name_map = name_map
247213
self._tool_name_map = tool_name_map
248-
self._node_accumulators = node_accumulators
214+
self._node_metrics = node_metrics
249215
return root
250216

251217
def _make_on_handoff(
252218
self,
253219
src: str,
254220
tgt: str,
255221
path: List[str],
256-
acc: _NodeMetricsAccumulator,
222+
metrics: LDAIMetrics,
257223
state: _RunState,
258224
):
259225
def on_handoff(run_ctx: Any) -> None:
260-
self._handle_handoff(run_ctx, src, tgt, path, acc, state)
226+
self._handle_handoff(run_ctx, src, tgt, path, metrics, state)
261227
return on_handoff
262228

263229
def _handle_handoff(
@@ -266,7 +232,7 @@ def _handle_handoff(
266232
src: str,
267233
tgt: str,
268234
path: List[str],
269-
acc: _NodeMetricsAccumulator,
235+
metrics: LDAIMetrics,
270236
state: _RunState,
271237
) -> None:
272238
path.append(tgt)
@@ -276,46 +242,43 @@ def _handle_handoff(
276242
duration_ms = (now_ns - state.last_handoff_ns) // 1_000_000
277243
state.last_handoff_ns = now_ns
278244

279-
usage: Optional[TokenUsage] = None
280245
try:
281-
usage = extract_usage_from_request_entry(
246+
metrics.usage = extract_usage_from_request_entry(
282247
run_ctx.usage.request_usage_entries[-1]
283248
)
284249
except Exception:
285250
pass
286251

287-
acc.set_usage(usage)
288-
acc.set_duration_ms(int(duration_ms))
252+
metrics.duration_ms = int(duration_ms)
289253

290254
def _flush_final_segment(self, state: _RunState, result: Any) -> None:
291255
"""Record duration/tokens for the last active agent (no handoff after it)."""
292256
if not state.last_node_key:
293257
return
294-
acc = self._node_accumulators.get(state.last_node_key)
295-
if acc is None:
258+
metrics = self._node_metrics.get(state.last_node_key)
259+
if metrics is None:
296260
return
297261

298262
now_ns = time.perf_counter_ns()
299-
duration_ms = (now_ns - state.last_handoff_ns) // 1_000_000
263+
metrics.duration_ms = int((now_ns - state.last_handoff_ns) // 1_000_000)
300264

301-
usage: Optional[TokenUsage] = None
302265
try:
303-
usage = extract_usage_from_request_entry(
266+
metrics.usage = extract_usage_from_request_entry(
304267
result.context_wrapper.usage.request_usage_entries[-1]
305268
)
306269
except Exception:
307270
pass
308271

309-
acc.set_usage(usage)
310-
acc.set_duration_ms(int(duration_ms))
311-
312272
def _collect_tool_calls(self, result: Any) -> None:
313273
"""Collect all tool calls from the run result, attributed to the node that called them."""
314274
for agent_name, tool_fn_name in get_tool_calls_from_run_items(result.new_items):
315275
agent_key = self._agent_name_map.get(agent_name, agent_name)
316276
tool_name = self._tool_name_map.get(tool_fn_name)
317277
if tool_name is None:
318278
continue
319-
acc = self._node_accumulators.get(agent_key)
320-
if acc is not None:
321-
acc.add_tool_call(tool_name)
279+
metrics = self._node_metrics.get(agent_key)
280+
if metrics is not None:
281+
if metrics.tool_calls is None:
282+
metrics.tool_calls = [tool_name]
283+
else:
284+
metrics.tool_calls.append(tool_name)

packages/ai-providers/server-ai-openai/tests/test_openai_agent_graph_runner.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,5 +151,5 @@ async def test_openai_agent_graph_runner_run_success():
151151
node_factory = graph.get_node('root-agent').get_config().create_tracker
152152
node_factory.assert_not_called()
153153

154-
# Runner accumulates per-node metrics in _node_accumulators
155-
assert 'root-agent' in runner._node_accumulators
154+
# Runner accumulates per-node metrics in _node_metrics
155+
assert 'root-agent' in runner._node_metrics

0 commit comments

Comments
 (0)