-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathlanggraph_callback_handler.py
More file actions
180 lines (159 loc) · 6.21 KB
/
langgraph_callback_handler.py
File metadata and controls
180 lines (159 loc) · 6.21 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
import time
from typing import Any, Dict, List, Optional, Set
from uuid import UUID
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import ChatGeneration, LLMResult
from ldai.providers.types import LDAIMetrics
from ldai.tracker import TokenUsage
from ldai_langchain.langchain_helper import get_ai_usage_from_response
class LDMetricsCallbackHandler(BaseCallbackHandler):
    """
    CAUTION:
    This feature is experimental and should NOT be considered ready for production use.
    It may change or be removed without notice and is not subject to backwards
    compatibility guarantees.

    LangChain callback handler that collects per-node metrics during a LangGraph run.

    Records token usage, tool calls, and duration for each agent node in the graph.
    Each node's :class:`~ldai.providers.types.LDAIMetrics` is built incrementally
    as callbacks fire. Access the ``node_metrics`` property after the run completes
    to retrieve the accumulated per-node metrics.

    Attribution works by remembering which chain ``run_id`` belongs to which agent
    node (a run named after the node itself, or its ``<node>__tools`` companion)
    and resolving LLM / tool callbacks through their ``parent_run_id``. The
    per-run bookkeeping is released when the chain run ends or errors.
    """

    # Suffix identifying the tool-execution companion node of an agent node.
    _TOOLS_SUFFIX = '__tools'

    def __init__(self, node_keys: Set[str], fn_name_to_config_key: Dict[str, str]):
        """
        Initialize the handler.

        :param node_keys: Set of LangGraph node keys that represent agent nodes
            (excludes ``__tools`` suffix nodes).
        :param fn_name_to_config_key: Mapping from tool function ``__name__`` to
            the LD config key for that tool (e.g. ``'fetch_weather'`` -> ``'get_weather_open_meteo'``).
        """
        super().__init__()
        self._node_keys = node_keys
        self._fn_name_to_config_key = fn_name_to_config_key
        # run_id -> node_key for chain runs that are still active
        self._run_to_node: Dict[UUID, str] = {}
        # start time (ns) per active run_id — keyed by run_id to handle re-entrant nodes
        self._node_start_ns: Dict[UUID, int] = {}
        # per-node metrics, built incrementally as callbacks fire
        self._node_metrics: Dict[str, LDAIMetrics] = {}
        # execution path in order (deduplicated)
        self._path: List[str] = []
        self._path_set: Set[str] = set()

    # ------------------------------------------------------------------
    # Properties
    # ------------------------------------------------------------------

    @property
    def path(self) -> List[str]:
        """Execution path through the graph in order (each node listed once)."""
        return list(self._path)

    @property
    def node_metrics(self) -> Dict[str, LDAIMetrics]:
        """Per-node metrics keyed by node key (a shallow copy of the internal map)."""
        return dict(self._node_metrics)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _metrics_for_parent(self, parent_run_id: Optional[UUID]) -> Optional[LDAIMetrics]:
        """Return the metrics of the agent node owning ``parent_run_id``, or ``None``."""
        if parent_run_id is None:
            return None
        node_key = self._run_to_node.get(parent_run_id)
        if node_key is None:
            return None
        return self._node_metrics.get(node_key)

    def _finish_run(self, run_id: UUID, *, success: bool) -> None:
        """Release bookkeeping for ``run_id`` and fold its duration into the node's metrics.

        :param run_id: The chain run that just ended or errored.
        :param success: Outcome of this run; written to the node's ``success``
            flag so the flag reflects the node's most recent timed run.
        """
        # Pop rather than get: the run is over, so no further callbacks should
        # resolve to it, and keeping it would grow the map without bound.
        node_key = self._run_to_node.pop(run_id, None)
        if node_key is None:
            return
        start_ns = self._node_start_ns.pop(run_id, None)
        if start_ns is None:
            # ``__tools`` companion runs are attributed but never timed.
            return
        elapsed_ms = (time.perf_counter_ns() - start_ns) // 1_000_000
        metrics = self._node_metrics.get(node_key)
        if metrics is None:
            return
        metrics.success = success
        metrics.duration_ms = (metrics.duration_ms or 0) + elapsed_ms

    # ------------------------------------------------------------------
    # Callbacks
    # ------------------------------------------------------------------

    def on_chain_start(
        self,
        serialized: Optional[Dict[str, Any]],
        inputs: Dict[str, Any],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        name: Optional[str] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
        """Record the start of a chain run and attribute it to an agent node.

        A run named after an agent node key is timed. On the node's first
        appearance it is also appended to the execution path and given a
        fresh ``LDAIMetrics(success=False)``. A run named ``<node>__tools`` is
        mapped to its owning agent node (untimed) so tool callbacks beneath it
        can be attributed. All other runs are ignored.
        """
        if name is None:
            return
        if name in self._node_keys:
            self._run_to_node[run_id] = name
            self._node_start_ns[run_id] = time.perf_counter_ns()
            if name not in self._path_set:
                self._path.append(name)
                self._path_set.add(name)
                self._node_metrics[name] = LDAIMetrics(success=False)
        elif name.endswith(self._TOOLS_SUFFIX):
            owner = name[: -len(self._TOOLS_SUFFIX)]
            if owner in self._node_keys:
                self._run_to_node[run_id] = owner

    def on_chain_end(
        self,
        outputs: Dict[str, Any],
        *,
        run_id: UUID,
        **kwargs: Any,
    ) -> None:
        """Record a successful end of a chain run and accumulate its elapsed duration."""
        self._finish_run(run_id, success=True)

    def on_chain_error(
        self,
        error: BaseException,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> None:
        """Record a failed chain run.

        The elapsed time of the failed run is still added to the node's
        duration, and the node is marked unsuccessful. Without this hook a
        failing run would leave its start time and run mapping behind forever.
        NOTE(review): every exception delivered here counts as a failure —
        confirm that is desired for any framework control-flow exceptions.
        """
        self._finish_run(run_id, success=False)

    def on_llm_end(
        self,
        response: LLMResult,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> None:
        """Accumulate token usage for the node that owns this LLM call.

        Only the first generation of the first prompt is inspected, and only
        when it is a :class:`ChatGeneration` whose message carries usage data.
        """
        metrics = self._metrics_for_parent(parent_run_id)
        if metrics is None:
            return
        try:
            gen = response.generations[0][0]
        except (IndexError, TypeError):
            return
        if not isinstance(gen, ChatGeneration):
            return
        usage = get_ai_usage_from_response(gen.message)
        if usage is None:
            return
        existing = metrics.tokens
        if existing is None:
            metrics.tokens = usage
        else:
            metrics.tokens = TokenUsage(
                total=existing.total + usage.total,
                input=existing.input + usage.input,
                output=existing.output + usage.output,
            )

    def on_tool_end(
        self,
        output: Any,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        name: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Record a tool invocation (by its LD config key) for the owning agent node.

        Tools whose ``name`` has no entry in ``fn_name_to_config_key`` are ignored.
        """
        if name is None:
            return
        config_key = self._fn_name_to_config_key.get(name)
        if config_key is None:
            return
        metrics = self._metrics_for_parent(parent_run_id)
        if metrics is None:
            return
        if metrics.tool_calls is None:
            metrics.tool_calls = [config_key]
        else:
            metrics.tool_calls.append(config_key)