11import re
22import time
3- from typing import Any , Dict , List , Optional
3+ from typing import Any , Dict , List
44
55from ldai import log
66from ldai .agent_graph import AgentGraphDefinition , AgentGraphNode
77from ldai .providers import AgentGraphRunner , ToolRegistry
88from ldai .providers .types import AgentGraphRunnerResult , GraphMetrics , LDAIMetrics
9- from ldai .tracker import TokenUsage
109
1110from ldai_openai .openai_helper import (
1211 extract_usage_from_request_entry ,
@@ -22,34 +21,6 @@ def _sanitize_agent_name(key: str) -> str:
2221 return re .sub (r'[^a-zA-Z0-9_]' , '_' , key )
2322
2423
25- class _NodeMetricsAccumulator :
26- """Mutable per-node metrics collected during a run (replaces LDAIConfigTracker)."""
27-
28- def __init__ (self ) -> None :
29- self .usage : Optional [TokenUsage ] = None
30- self .duration_ms : Optional [int ] = None
31- self .tool_calls : List [str ] = []
32- self .success : bool = True
33-
34- def set_usage (self , usage : Optional [TokenUsage ]) -> None :
35- if usage is not None :
36- self .usage = usage
37-
38- def set_duration_ms (self , duration_ms : int ) -> None :
39- self .duration_ms = duration_ms
40-
41- def add_tool_call (self , tool_name : str ) -> None :
42- self .tool_calls .append (tool_name )
43-
44- def to_ldai_metrics (self ) -> LDAIMetrics :
45- return LDAIMetrics (
46- success = self .success ,
47- usage = self .usage ,
48- duration_ms = self .duration_ms ,
49- tool_calls = self .tool_calls if self .tool_calls else None ,
50- )
51-
52-
5324class _RunState :
5425 """Mutable state shared across handoff and tool callbacks during a single run."""
5526
@@ -90,7 +61,7 @@ def __init__(
9061 self ._tools = tools
9162 self ._agent_name_map : Dict [str , str ] = {}
9263 self ._tool_name_map : Dict [str , str ] = {}
93- self ._node_accumulators : Dict [str , _NodeMetricsAccumulator ] = {}
64+ self ._node_metrics : Dict [str , LDAIMetrics ] = {}
9465
9566 async def run (self , input : Any ) -> AgentGraphRunnerResult :
9667 """
@@ -122,11 +93,6 @@ async def run(self, input: Any) -> AgentGraphRunnerResult:
12293 duration_ms = (time .perf_counter_ns () - start_ns ) // 1_000_000
12394 token_usage = get_ai_usage_from_response (result )
12495
125- node_metrics = {
126- key : acc .to_ldai_metrics ()
127- for key , acc in self ._node_accumulators .items ()
128- }
129-
13096 return AgentGraphRunnerResult (
13197 content = str (result .final_output ),
13298 raw = result ,
@@ -135,7 +101,7 @@ async def run(self, input: Any) -> AgentGraphRunnerResult:
135101 path = path ,
136102 duration_ms = duration_ms ,
137103 usage = token_usage ,
138- node_metrics = node_metrics ,
104+ node_metrics = self . _node_metrics ,
139105 ),
140106 )
141107 except Exception as exc :
@@ -185,12 +151,12 @@ def _build_agents(
185151
186152 name_map : Dict [str , str ] = {}
187153 tool_name_map : Dict [str , str ] = {}
188- node_accumulators : Dict [str , _NodeMetricsAccumulator ] = {}
154+ node_metrics : Dict [str , LDAIMetrics ] = {}
189155
190156 def build_node (node : AgentGraphNode , ctx : dict ) -> Any :
191157 node_config = node .get_config ()
192- acc = _NodeMetricsAccumulator ( )
193- node_accumulators [node_config .key ] = acc
158+ metrics = LDAIMetrics ( success = True )
159+ node_metrics [node_config .key ] = metrics
194160 model = node_config .model
195161
196162 if not model :
@@ -211,7 +177,7 @@ def build_node(node: AgentGraphNode, ctx: dict) -> Any:
211177 node_config .key ,
212178 target_key ,
213179 path ,
214- acc ,
180+ metrics ,
215181 state ,
216182 ),
217183 )
@@ -245,19 +211,19 @@ def build_node(node: AgentGraphNode, ctx: dict) -> Any:
245211 root = self ._graph .reverse_traverse (fn = build_node )
246212 self ._agent_name_map = name_map
247213 self ._tool_name_map = tool_name_map
248- self ._node_accumulators = node_accumulators
214+ self ._node_metrics = node_metrics
249215 return root
250216
251217 def _make_on_handoff (
252218 self ,
253219 src : str ,
254220 tgt : str ,
255221 path : List [str ],
256- acc : _NodeMetricsAccumulator ,
222+ metrics : LDAIMetrics ,
257223 state : _RunState ,
258224 ):
259225 def on_handoff (run_ctx : Any ) -> None :
260- self ._handle_handoff (run_ctx , src , tgt , path , acc , state )
226+ self ._handle_handoff (run_ctx , src , tgt , path , metrics , state )
261227 return on_handoff
262228
263229 def _handle_handoff (
@@ -266,7 +232,7 @@ def _handle_handoff(
266232 src : str ,
267233 tgt : str ,
268234 path : List [str ],
269- acc : _NodeMetricsAccumulator ,
235+ metrics : LDAIMetrics ,
270236 state : _RunState ,
271237 ) -> None :
272238 path .append (tgt )
@@ -276,46 +242,43 @@ def _handle_handoff(
276242 duration_ms = (now_ns - state .last_handoff_ns ) // 1_000_000
277243 state .last_handoff_ns = now_ns
278244
279- usage : Optional [TokenUsage ] = None
280245 try :
281- usage = extract_usage_from_request_entry (
246+ metrics . usage = extract_usage_from_request_entry (
282247 run_ctx .usage .request_usage_entries [- 1 ]
283248 )
284249 except Exception :
285250 pass
286251
287- acc .set_usage (usage )
288- acc .set_duration_ms (int (duration_ms ))
252+ metrics .duration_ms = int (duration_ms )
289253
290254 def _flush_final_segment (self , state : _RunState , result : Any ) -> None :
291255 """Record duration/tokens for the last active agent (no handoff after it)."""
292256 if not state .last_node_key :
293257 return
294- acc = self ._node_accumulators .get (state .last_node_key )
295- if acc is None :
258+ metrics = self ._node_metrics .get (state .last_node_key )
259+ if metrics is None :
296260 return
297261
298262 now_ns = time .perf_counter_ns ()
299- duration_ms = ( now_ns - state .last_handoff_ns ) // 1_000_000
263+ metrics . duration_ms = int (( now_ns - state .last_handoff_ns ) // 1_000_000 )
300264
301- usage : Optional [TokenUsage ] = None
302265 try :
303- usage = extract_usage_from_request_entry (
266+ metrics . usage = extract_usage_from_request_entry (
304267 result .context_wrapper .usage .request_usage_entries [- 1 ]
305268 )
306269 except Exception :
307270 pass
308271
309- acc .set_usage (usage )
310- acc .set_duration_ms (int (duration_ms ))
311-
312272 def _collect_tool_calls (self , result : Any ) -> None :
313273 """Collect all tool calls from the run result, attributed to the node that called them."""
314274 for agent_name , tool_fn_name in get_tool_calls_from_run_items (result .new_items ):
315275 agent_key = self ._agent_name_map .get (agent_name , agent_name )
316276 tool_name = self ._tool_name_map .get (tool_fn_name )
317277 if tool_name is None :
318278 continue
319- acc = self ._node_accumulators .get (agent_key )
320- if acc is not None :
321- acc .add_tool_call (tool_name )
279+ metrics = self ._node_metrics .get (agent_key )
280+ if metrics is not None :
281+ if metrics .tool_calls is None :
282+ metrics .tool_calls = [tool_name ]
283+ else :
284+ metrics .tool_calls .append (tool_name )
0 commit comments