-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathlanggraph_callback_handler.py
More file actions
234 lines (201 loc) · 8.32 KB
/
langgraph_callback_handler.py
File metadata and controls
234 lines (201 loc) · 8.32 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
import time
from typing import Any, Dict, List, Optional, Set
from uuid import UUID
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import ChatGeneration, LLMResult
from ldai.agent_graph import AgentGraphDefinition
from ldai.tracker import TokenUsage
from ldai_langchain.langchain_helper import get_ai_usage_from_response
class LDMetricsCallbackHandler(BaseCallbackHandler):
    """
    CAUTION:
    This feature is experimental and should NOT be considered ready for production use.
    It may change or be removed without notice and is not subject to backwards
    compatibility guarantees.

    LangChain callback handler that collects per-node metrics during a LangGraph run.

    Records token usage, tool calls, and duration for each agent node in the graph,
    then flushes them to LaunchDarkly trackers after the run completes via ``flush()``.

    Attribution follows ``run_id`` / ``parent_run_id`` links:

    - A chain run named after an agent node key is mapped to that node.
    - A chain run named ``<node>__tools`` is mapped to ``<node>``.
    - Any other chain run started beneath an already-mapped run inherits that node.
    - LLM and tool callbacks are attributed via the node mapped to their parent run.
    """

    def __init__(self, node_keys: Set[str], fn_name_to_config_key: Dict[str, str]):
        """
        Initialize the handler.

        :param node_keys: Set of LangGraph node keys that represent agent nodes
            (excludes ``__tools`` suffix nodes).
        :param fn_name_to_config_key: Mapping from tool function ``__name__`` to
            the LD config key for that tool
            (e.g. ``'fetch_weather'`` -> ``'get_weather_open_meteo'``).
        """
        super().__init__()
        self._node_keys = node_keys
        self._fn_name_to_config_key = fn_name_to_config_key
        # run_id -> node_key for chain runs that are still active. Entries are
        # removed in on_chain_end / on_chain_error, so only live runs are held.
        self._run_to_node: Dict[UUID, str] = {}
        # Accumulated token usage per node.
        self._node_tokens: Dict[str, TokenUsage] = {}
        # Tool config keys called per node, in callback order.
        self._node_tool_calls: Dict[str, List[str]] = {}
        # Start time (ns) per active agent-node run_id. Keyed by run_id rather
        # than node key so a re-entered node times each visit separately.
        self._node_start_ns: Dict[UUID, int] = {}
        # Accumulated duration (ms) per node, summed over all of its visits.
        self._node_duration_ms: Dict[str, int] = {}
        # Execution path in first-visit order; the set gives O(1) de-duplication.
        self._path: List[str] = []
        self._path_set: Set[str] = set()

    # ------------------------------------------------------------------
    # Properties
    # ------------------------------------------------------------------

    @property
    def path(self) -> List[str]:
        """Execution path through the graph in first-visit order (a copy)."""
        return list(self._path)

    @property
    def node_tokens(self) -> Dict[str, TokenUsage]:
        """Accumulated token usage per node key (a shallow copy)."""
        return dict(self._node_tokens)

    @property
    def node_tool_calls(self) -> Dict[str, List[str]]:
        """Tool config keys called per node key (copied lists)."""
        return {k: list(v) for k, v in self._node_tool_calls.items()}

    @property
    def node_durations_ms(self) -> Dict[str, int]:
        """Accumulated duration in milliseconds per node key (a copy)."""
        return dict(self._node_duration_ms)

    # ------------------------------------------------------------------
    # Callbacks
    # ------------------------------------------------------------------

    def on_chain_start(
        self,
        serialized: Dict[str, Any],
        inputs: Dict[str, Any],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        name: Optional[str] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
        """
        Record the start of a chain run and attribute it to an agent node.

        - A run named after an agent node key starts that node's timer. The
          node is appended to the execution path on its first visit.
        - A run named ``<node>__tools`` (where ``<node>`` is an agent node key)
          is attributed to ``<node>``, so tool events inside it count toward
          the owning agent node. No timer is started for it.
        - Any other run whose parent is already attributed inherits the parent's
          node. Without this, LLM/tool calls nested under intermediate chains
          inside a node would be dropped, because those callbacks look up only
          their direct parent run.
        """
        if name is not None:
            if name in self._node_keys:
                self._run_to_node[run_id] = name
                self._node_start_ns[run_id] = time.perf_counter_ns()
                if name not in self._path_set:
                    self._path.append(name)
                    self._path_set.add(name)
                return
            if name.endswith('__tools'):
                owner = name[: -len('__tools')]
                if owner in self._node_keys:
                    # Attribute tool events to the owning agent node.
                    self._run_to_node[run_id] = owner
                    return
        if parent_run_id is not None:
            inherited = self._run_to_node.get(parent_run_id)
            if inherited is not None:
                self._run_to_node[run_id] = inherited

    def on_chain_end(
        self,
        outputs: Dict[str, Any],
        *,
        run_id: UUID,
        **kwargs: Any,
    ) -> None:
        """Finish a chain run: add agent-node elapsed time and drop its bookkeeping."""
        node_key = self._run_to_node.pop(run_id, None)
        start_ns = self._node_start_ns.pop(run_id, None)
        if node_key is None or start_ns is None:
            # Untracked run, or a ``__tools`` / nested run that has no timer.
            return
        elapsed_ms = (time.perf_counter_ns() - start_ns) // 1_000_000
        self._node_duration_ms[node_key] = (
            self._node_duration_ms.get(node_key, 0) + elapsed_ms
        )

    def on_chain_error(
        self,
        error: BaseException,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> None:
        """
        Drop bookkeeping for a chain run that raised instead of completing.

        No duration is recorded for the failed run. This only stops its node
        mapping and start timestamp from lingering for the handler's lifetime.
        """
        self._run_to_node.pop(run_id, None)
        self._node_start_ns.pop(run_id, None)

    def on_llm_end(
        self,
        response: LLMResult,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> None:
        """
        Accumulate token usage for the node that owns this LLM call.

        Only the first generation of the first prompt is inspected, and only
        when it is a ``ChatGeneration``. Calls whose parent run is not
        attributed to a node, or whose usage cannot be extracted, are ignored.
        """
        if parent_run_id is None:
            return
        node_key = self._run_to_node.get(parent_run_id)
        if node_key is None:
            return
        try:
            gen = response.generations[0][0]
        except (IndexError, TypeError):
            return
        if not isinstance(gen, ChatGeneration):
            return
        usage = get_ai_usage_from_response(gen.message)
        if usage is None:
            return
        existing = self._node_tokens.get(node_key)
        if existing is None:
            self._node_tokens[node_key] = usage
        else:
            self._node_tokens[node_key] = TokenUsage(
                total=existing.total + usage.total,
                input=existing.input + usage.input,
                output=existing.output + usage.output,
            )

    def on_tool_end(
        self,
        output: Any,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        name: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """
        Record a tool invocation for the owning agent node.

        The tool's ``name`` is translated to its LD config key via
        ``fn_name_to_config_key``. Tools without a mapping are not tracked.
        """
        if parent_run_id is None or name is None:
            return
        node_key = self._run_to_node.get(parent_run_id)
        if node_key is None:
            return
        config_key = self._fn_name_to_config_key.get(name)
        if config_key is None:
            # Tool is not a registered functional tool (e.g. a handoff tool) — skip tracking.
            return
        self._node_tool_calls.setdefault(node_key, []).append(config_key)

    # ------------------------------------------------------------------
    # Flush
    # ------------------------------------------------------------------

    async def flush(
        self,
        graph: AgentGraphDefinition,
        eval_tasks: Optional[Dict[str, List[Any]]] = None,
    ) -> None:
        """
        Emit all collected per-node metrics to the LaunchDarkly trackers.

        Call this once after the graph run completes. For each visited node
        that has a config tracker, this emits:

        - token usage (if any was collected),
        - accumulated duration (if the node was timed),
        - a success event,
        - one tool-call event per recorded tool call,
        - judge results whose ``success`` flag is set.

        :param graph: The AgentGraphDefinition whose nodes hold the LD config trackers.
        :param eval_tasks: Optional dict mapping node key to a list of awaitables that
            return judge evaluation results. Multiple tasks arise when a node is visited
            more than once (e.g. in a graph with cycles).
        """
        # ``_path`` is already de-duplicated, so each node is flushed at most once.
        for node_key in self._path:
            node = graph.get_node(node_key)
            if not node:
                continue
            config_tracker = node.get_config().create_tracker()
            if not config_tracker:
                continue

            usage = self._node_tokens.get(node_key)
            if usage:
                config_tracker.track_tokens(usage)
            duration = self._node_duration_ms.get(node_key)
            if duration is not None:
                config_tracker.track_duration(duration)
            config_tracker.track_success()
            for tool_key in self._node_tool_calls.get(node_key, []):
                config_tracker.track_tool_call(tool_key)

            # NOTE(review): eval tasks keyed by nodes that are not on the path
            # (or have no tracker) are never awaited here — confirm that is intended.
            if not eval_tasks:
                continue
            for eval_task in eval_tasks.get(node_key, []):
                results = await eval_task
                # A judge yielding no results must not abort flushing later nodes.
                for r in results or ():
                    if r.success:
                        config_tracker.track_judge_result(r)