-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathlanggraph_agent_graph_runner.py
More file actions
359 lines (304 loc) · 15.1 KB
/
langgraph_agent_graph_runner.py
File metadata and controls
359 lines (304 loc) · 15.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
"""LangGraph agent graph runner for LaunchDarkly AI SDK."""
import asyncio
import time
from contextvars import ContextVar
from typing import Annotated, Any, Dict, List, Set, Tuple
from ldai import log
from ldai.agent_graph import AgentGraphDefinition, AgentGraphNode
from ldai.providers import AgentGraphResult, AgentGraphRunner, ToolRegistry
from ldai.providers.types import LDAIMetrics
from ldai_langchain.langchain_helper import (
build_structured_tools,
create_langchain_model,
extract_last_message_content,
sum_token_usage_from_messages,
)
from ldai_langchain.langgraph_callback_handler import LDMetricsCallbackHandler
# Per-run eval task accumulator, isolated per concurrent run() call via ContextVar.
# run() sets a fresh dict (node key -> list of evaluation tasks) for the duration
# of one invocation and resets it in ``finally``. Node functions append to that
# dict in place via ``.get({}).setdefault(...)`` rather than calling ``set()``, so
# run() sees the appended tasks. If no value has been set (node invoked outside
# run()), ``.get({})`` yields a throwaway dict and the task is not collected.
_run_eval_tasks: ContextVar[Dict[str, List[asyncio.Task]]] = ContextVar('_run_eval_tasks')
def _make_handoff_tool(child_key: str, description: str) -> Any:
"""
Create a tool that transfers control to ``child_key``.
Uses the ``@tool`` decorator with ``InjectedState`` + ``InjectedToolCallId``
so LangGraph's ToolNode handles the ``Command`` return value correctly.
The tool explicitly creates a ToolMessage in ``Command.update`` to satisfy
the LangChain/OpenAI message-chain contract.
"""
from typing import Annotated as _Annotated
from langchain_core.messages import ToolMessage
from langchain_core.tools import tool
from langchain_core.tools.base import InjectedToolCallId
from langgraph.prebuilt import InjectedState
from langgraph.types import Command
tool_name = f"transfer_to_{child_key.replace('-', '_')}"
@tool(tool_name, description=description)
def handoff(
state: _Annotated[Any, InjectedState], # noqa: ARG001
tool_call_id: _Annotated[str, InjectedToolCallId],
) -> Command:
tool_message = ToolMessage(
content=f'Transferred to {child_key}',
name=tool_name,
tool_call_id=tool_call_id,
)
return Command(goto=child_key, update={'messages': [tool_message]})
return handoff
class LangGraphAgentGraphRunner(AgentGraphRunner):
    """
    CAUTION:
    This feature is experimental and should NOT be considered ready for production use.
    It may change or be removed without notice and is not subject to backwards
    compatibility guarantees.

    AgentGraphRunner implementation for LangGraph.

    Compiles and runs the agent graph with LangGraph and automatically records
    graph- and node-level AI metric data to the LaunchDarkly trackers on the
    graph definition and each node.

    Requires ``langgraph`` to be installed.
    """

    def __init__(
        self,
        graph: AgentGraphDefinition,
        tools: ToolRegistry,
        recursion_limit: int = 25,
    ):
        """
        Initialize the runner.

        :param graph: The AgentGraphDefinition to execute
        :param tools: Registry mapping tool names to callables (langchain-compatible)
        :param recursion_limit: Value passed to LangGraph as ``recursion_limit`` in
            the invocation config of every run. Defaults to 25, the value that was
            previously hard-coded.
        """
        self._graph = graph
        self._tools = tools
        self._recursion_limit = recursion_limit
        # Compiled lazily on first run() and then reused.
        self._compiled: Any = None
        self._fn_name_to_config_key: Dict[str, str] = {}
        self._node_keys: Set[str] = set()

    def _ensure_compiled(self) -> None:
        """Build and cache the compiled graph if not already done."""
        if self._compiled is None:
            compiled, fn_name_to_config_key, node_keys = self._build_graph()
            self._compiled = compiled
            self._fn_name_to_config_key = fn_name_to_config_key
            self._node_keys = node_keys

    def _build_graph(self) -> Tuple[Any, Dict[str, str], Set[str]]:
        """
        Build and compile the LangGraph StateGraph from the AgentGraphDefinition.

        Every agent node becomes a LangGraph node that calls its model. A node
        with tools also gets a companion ``<node_key>__tools`` ToolNode. A node
        with a model and more than one outgoing edge gets one handoff tool per
        child, so the LLM chooses the next agent.

        :return: Tuple of (compiled_graph, fn_name_to_config_key, node_keys) where
            fn_name_to_config_key maps each tool's ``name`` to its LD config key, and
            node_keys is the set of all agent node keys in the graph.
        """
        from langchain_core.messages import SystemMessage
        from langgraph.graph import END, START, StateGraph
        from langgraph.graph.message import add_messages
        from langgraph.prebuilt import ToolNode, tools_condition
        from typing_extensions import TypedDict

        class WorkflowState(TypedDict):
            messages: Annotated[List[Any], add_messages]

        agent_builder: StateGraph = StateGraph(WorkflowState)
        root_node = self._graph.root()
        root_key = root_node.get_key() if root_node else None
        tools_ref = self._tools
        graph_structure: List[str] = []
        fn_name_to_config_key: Dict[str, str] = {}
        node_keys: Set[str] = set()

        def handle_traversal(node: AgentGraphNode, ctx: dict) -> None:
            node_config = node.get_config()
            node_key = node.get_key()
            node_keys.add(node_key)
            instructions = getattr(node_config, 'instructions', None)
            outgoing_edges = node.get_edges()

            lc_model = None
            tool_fns: list = []
            if node_config.model:
                # create_langchain_model receives only the node config; tools (and
                # any handoff tools) are bound further below with bind_tools.
                lc_model = create_langchain_model(node_config)
                tool_fns = build_structured_tools(node_config, tools_ref)
                # build_structured_tools returns tools whose ``name`` is the LD
                # config key, so each name maps to itself. The callback handler
                # receives this map for metric attribution.
                for structured_tool in tool_fns:
                    tool_name = getattr(structured_tool, 'name', None)
                    if tool_name:
                        fn_name_to_config_key[tool_name] = tool_name

            # For nodes with multiple children, create a handoff tool per child so the
            # LLM decides which agent to route to. Uses Command(goto=child_key) so
            # LangGraph routes to the target without looping back here.
            handoff_fns: list = []
            if lc_model and len(outgoing_edges) > 1:
                for edge in outgoing_edges:
                    child_node = self._graph.get_node(edge.target_config)
                    child_instructions = (
                        getattr(child_node.get_config(), 'instructions', None)
                        if child_node
                        else None
                    )
                    description = (
                        (edge.handoff or {}).get('description')
                        or (child_instructions[:120] if child_instructions else None)
                        or f"Transfer control to {edge.target_config}"
                    )
                    handoff_fns.append(_make_handoff_tool(edge.target_config, description))

            all_tools = tool_fns + handoff_fns
            model: Any
            if lc_model and all_tools:
                # When handoff tools are present, disable parallel tool calls so the LLM
                # picks exactly one destination rather than routing to multiple children.
                bind_kwargs: Dict[str, Any] = {'parallel_tool_calls': False} if handoff_fns else {}
                model = lc_model.bind_tools(all_tools, **bind_kwargs)
            else:
                model = lc_model

            def make_node_fn(bound_model: Any, node_instructions: Any, nk: str):
                """Return the async LangGraph node function for node ``nk``."""

                async def invoke(state: WorkflowState) -> dict:
                    if not bound_model:
                        return {'messages': []}
                    msgs = list(state['messages'])
                    if node_instructions:
                        msgs = [SystemMessage(content=node_instructions)] + msgs
                    response = await bound_model.ainvoke(msgs)

                    node_obj = self._graph.get_node(nk)
                    # Skip evaluation when the node or its evaluator is absent
                    # instead of failing the whole run with AttributeError.
                    evaluator = (
                        getattr(node_obj.get_config(), 'evaluator', None)
                        if node_obj is not None
                        else None
                    )
                    if evaluator is not None:
                        input_text = '\r\n'.join(
                            m.content if isinstance(m.content, str) else str(m.content)
                            for m in msgs
                        )
                        output_text = (
                            response.content if hasattr(response, 'content') else str(response)
                        )
                        task = evaluator.evaluate(input_text, output_text)
                        # Outside run() no accumulator is set, so .get({}) returns a
                        # throwaway dict and the task is not collected.
                        _run_eval_tasks.get({}).setdefault(nk, []).append(task)
                    return {'messages': [response]}

                # The node key doubles as the function name.
                invoke.__name__ = nk
                return invoke

            invoke_fn = make_node_fn(model, instructions, node_key)
            agent_builder.add_node(node_key, invoke_fn)
            if node_key == root_key:
                agent_builder.add_edge(START, node_key)

            # Collect node info for the graph-structure debug log.
            tool_names = [str(getattr(t, 'name', None) or getattr(t, '__name__', t)) for t in tool_fns]
            edge_targets = [e.target_config for e in outgoing_edges]
            node_desc = node_key
            if tool_names:
                node_desc += f"[tools:{','.join(tool_names)}]"
            if handoff_fns:
                node_desc += f"[handoff:{','.join(edge_targets)}]"
            elif edge_targets:
                node_desc += f"→{','.join(edge_targets)}"
            else:
                node_desc += "(terminal)"
            graph_structure.append(node_desc)

            if all_tools:
                tools_node_key = f"{node_key}__tools"
                agent_builder.add_node(tools_node_key, ToolNode(all_tools))
                if not handoff_fns:
                    # No handoff tools: standard loop-back after tool execution, then
                    # continue to the first child (or END) once no tool is called.
                    after_loop = outgoing_edges[0].target_config if outgoing_edges else END
                    if len(outgoing_edges) > 1:
                        log.warning(
                            f"Node '{node_key}' has {len(outgoing_edges)} outgoing edges but no handoff "
                            "tools; only the first edge will be used after the tool loop. "
                            "Use handoff tools for multi-child routing."
                        )
                    agent_builder.add_edge(tools_node_key, node_key)
                    agent_builder.add_conditional_edges(
                        node_key,
                        tools_condition,
                        {"tools": tools_node_key, END: after_loop},
                    )
                elif not tool_fns:
                    # Only handoff tools: no loop-back needed.
                    # Command(goto=child_key) handles routing to the target.
                    agent_builder.add_conditional_edges(
                        node_key,
                        tools_condition,
                        {"tools": tools_node_key, END: END},
                    )
                else:
                    # Both functional and handoff tools. A static loop-back edge would
                    # fan-out with Command(goto=child_key) from handoff tools, so use a
                    # conditional edge that only loops back for functional tool results.
                    handoff_names_set = frozenset(getattr(t, 'name', '') for t in handoff_fns)

                    def make_after_tools_router(parent_key: str, ht_names: frozenset):
                        def route(state: WorkflowState) -> str:
                            msgs = state['messages']
                            if msgs:
                                last = msgs[-1]
                                # A handoff ToolMessage carries the handoff tool's name.
                                if hasattr(last, 'name') and last.name in ht_names:
                                    return END
                            return parent_key

                        return route

                    agent_builder.add_conditional_edges(
                        tools_node_key,
                        make_after_tools_router(node_key, handoff_names_set),
                        {node_key: node_key, END: END},
                    )
                    agent_builder.add_conditional_edges(
                        node_key,
                        tools_condition,
                        {"tools": tools_node_key, END: END},
                    )
            else:
                if node.is_terminal():
                    agent_builder.add_edge(node_key, END)
                for edge in outgoing_edges:
                    agent_builder.add_edge(node_key, edge.target_config)
            return None

        self._graph.traverse(fn=handle_traversal)

        graph_key_str = self._graph._agent_graph.key or 'unknown'
        log.debug(
            f"LangGraphAgentGraphRunner: graph='{graph_key_str}', root='{root_key}', "
            f"structure: {' | '.join(graph_structure)}"
        )

        compiled = agent_builder.compile()
        return compiled, fn_name_to_config_key, node_keys

    async def run(self, input: Any) -> AgentGraphResult:
        """
        Run the agent graph with the given input.

        Builds (once) and invokes the compiled LangGraph graph. Uses a LangChain
        callback handler to collect per-node metrics, then flushes them to
        LaunchDarkly trackers. Failures are logged and reported as an
        unsuccessful result rather than raised.

        :param input: The string prompt to send to the agent graph
        :return: AgentGraphResult with the final output and metrics
        """
        pending_eval_tasks: Dict[str, List[asyncio.Task]] = {}
        token = _run_eval_tasks.set(pending_eval_tasks)
        tracker = self._graph.create_tracker() if self._graph.create_tracker is not None else None
        start_ns = time.perf_counter_ns()
        try:
            from langchain_core.messages import HumanMessage

            self._ensure_compiled()
            handler = LDMetricsCallbackHandler(self._node_keys, self._fn_name_to_config_key)
            result = await self._compiled.ainvoke(  # type: ignore[call-overload]
                {'messages': [HumanMessage(content=str(input))]},
                config={'callbacks': [handler], 'recursion_limit': self._recursion_limit},
            )
            duration = (time.perf_counter_ns() - start_ns) // 1_000_000
            messages = result.get('messages', [])
            output = extract_last_message_content(messages)

            # Flush per-node metrics to LD trackers
            await handler.flush(self._graph, pending_eval_tasks)

            # Graph-level metrics
            if tracker:
                tracker.track_path(handler.path)
                tracker.track_duration(duration)
                tracker.track_invocation_success()
                tracker.track_total_tokens(sum_token_usage_from_messages(messages))

            return AgentGraphResult(
                output=output,
                raw=result,
                metrics=LDAIMetrics(success=True),
            )
        except Exception as exc:
            if isinstance(exc, ImportError):
                log.warning(
                    "langgraph is required for LangGraphAgentGraphRunner. "
                    "Install it with: pip install langgraph"
                )
            else:
                log.warning(f'LangGraphAgentGraphRunner run failed: {exc}')
            duration = (time.perf_counter_ns() - start_ns) // 1_000_000
            if tracker:
                tracker.track_duration(duration)
                tracker.track_invocation_failure()
            return AgentGraphResult(
                output='',
                raw=None,
                metrics=LDAIMetrics(success=False),
            )
        finally:
            _run_eval_tasks.reset(token)