-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathlanggraph_agent_graph_runner.py
More file actions
414 lines (352 loc) · 17 KB
/
langgraph_agent_graph_runner.py
File metadata and controls
414 lines (352 loc) · 17 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
"""LangGraph agent graph runner for LaunchDarkly AI SDK."""
import time
from typing import Annotated, Any, Dict, FrozenSet, List, Set, Tuple
from ldai import log
from ldai.agent_graph import AgentGraphDefinition, AgentGraphNode
from ldai.providers import AgentGraphRunner, ToolRegistry
from ldai.providers.types import AgentGraphRunnerResult, AIGraphMetrics, EvalRequest
from ldai_langchain.langchain_helper import (
build_structured_tools,
create_langchain_model,
extract_last_message_content,
sum_token_usage_from_messages,
)
from ldai_langchain.langgraph_callback_handler import LDMetricsCallbackHandler
def _message_content_to_str(content: Any) -> str:
"""Normalize a LangChain message ``content`` (string or list of parts) to a string."""
if isinstance(content, str):
return content
if isinstance(content, list):
parts: List[str] = []
for item in content:
if isinstance(item, str):
parts.append(item)
elif isinstance(item, dict):
text = item.get('text')
if isinstance(text, str):
parts.append(text)
return '\r\n'.join(parts)
return str(content)
def _maybe_record_eval_request(
    eval_requests: List[EvalRequest],
    node_key: str,
    msgs: List[Any],
    response: Any,
    handoff_tool_names: FrozenSet[str],
) -> None:
    """
    Record an :class:`EvalRequest` for ``node_key`` when ``response`` is final.

    A response counts as the agent's final output for this activation when it
    carries no tool calls, or when every tool call it carries names a handoff
    tool (the agent is finishing by transferring control). A response that
    requests any other tool is still mid tool-loop and is skipped. A response
    whose text is empty or whitespace-only is also skipped.

    :param eval_requests: List the new request is appended to (mutated in place).
    :param node_key: Key of the agent node that produced ``response``.
    :param msgs: Messages that were sent to the model; their text forms the input.
    :param response: The model response object.
    :param handoff_tool_names: Tool names treated as control-flow handoffs.
    """
    def call_name(call: Any) -> Any:
        # Tool calls may be plain dicts or objects exposing a ``name`` attribute.
        if isinstance(call, dict):
            return call.get('name')
        return getattr(call, 'name', None)

    requested_calls = getattr(response, 'tool_calls', None) or []
    if any(call_name(call) not in handoff_tool_names for call in requested_calls):
        # At least one non-handoff tool was requested: still working.
        return

    output_text = _message_content_to_str(getattr(response, 'content', response))
    if not (output_text and output_text.strip()):
        return

    if msgs:
        input_text = '\r\n'.join(
            _message_content_to_str(getattr(message, 'content', message))
            for message in msgs
        )
    else:
        input_text = ''

    eval_requests.append(
        EvalRequest(node_key=node_key, input=input_text, output=output_text)
    )
def _make_handoff_tool(child_key: str, description: str) -> Any:
    """
    Create a tool that transfers control to the graph node ``child_key``.

    Uses the ``@tool`` decorator with ``InjectedState`` + ``InjectedToolCallId``
    so LangGraph's ToolNode handles the ``Command`` return value correctly.
    The tool explicitly creates a ToolMessage in ``Command.update`` to satisfy
    the LangChain/OpenAI message-chain contract.

    The tool name is ``transfer_to_<key>``, where ``<key>`` is ``child_key``
    with every character other than ASCII letters, digits and ``_`` replaced
    by ``_``. Function-calling APIs restrict tool names to a small character
    set, so a config key containing e.g. ``.`` must not leak into the name.
    Keys made only of letters, digits, ``_`` and ``-`` map as before (``-``
    becomes ``_``).

    :param child_key: Node key to route to; used verbatim as the
        ``Command.goto`` target and in the ToolMessage text.
    :param description: Tool description presented to the model.
    :return: The decorated handoff tool.
    """
    import re

    from langchain_core.messages import ToolMessage
    from langchain_core.tools import tool
    from langchain_core.tools.base import InjectedToolCallId
    from langgraph.prebuilt import InjectedState
    from langgraph.types import Command

    safe_key = re.sub(r'[^A-Za-z0-9_]', '_', child_key)
    tool_name = f'transfer_to_{safe_key}'

    @tool(tool_name, description=description)
    def handoff(
        state: Annotated[Any, InjectedState],  # noqa: ARG001
        tool_call_id: Annotated[str, InjectedToolCallId],
    ) -> Command:
        # Answer the tool call with a ToolMessage so the message history stays
        # well-formed, then route to the child node.
        tool_message = ToolMessage(
            content=f'Transferred to {child_key}',
            name=tool_name,
            tool_call_id=tool_call_id,
        )
        return Command(goto=child_key, update={'messages': [tool_message]})

    return handoff
class LangGraphAgentGraphRunner(AgentGraphRunner):
    """
    CAUTION:
    This feature is experimental and should NOT be considered ready for production use.
    It may change or be removed without notice and is not subject to backwards
    compatibility guarantees.

    AgentGraphRunner implementation for LangGraph.

    Compiles and runs the agent graph with LangGraph and collects graph- and
    node-level metrics via a LangChain callback handler. Tracking events are
    emitted by the managed layer (:class:`~ldai.ManagedAgentGraph`) from the
    returned :class:`~ldai.providers.types.AgentGraphRunnerResult`.

    Requires ``langgraph`` to be installed.
    """

    # recursion_limit passed to LangGraph when the caller does not override it.
    DEFAULT_RECURSION_LIMIT: int = 25

    def __init__(
        self,
        graph: AgentGraphDefinition,
        tools: ToolRegistry,
        *,
        recursion_limit: int = DEFAULT_RECURSION_LIMIT,
    ):
        """
        Initialize the runner.

        :param graph: The AgentGraphDefinition to execute
        :param tools: Registry mapping tool names to callables (langchain-compatible)
        :param recursion_limit: Value passed as ``recursion_limit`` in the
            LangGraph run config on every :meth:`run`. Defaults to 25. Raise it
            for graphs whose tool loops and handoffs need more steps.
        :raises ValueError: If ``recursion_limit`` is less than 1.
        """
        if recursion_limit < 1:
            raise ValueError(f'recursion_limit must be >= 1, got {recursion_limit}')
        self._graph = graph
        self._tools = tools
        self._recursion_limit = recursion_limit

    def _build_graph(
        self, eval_requests: List[EvalRequest]
    ) -> Tuple[Any, Dict[str, str], Set[str]]:
        """
        Build and compile the LangGraph StateGraph from the AgentGraphDefinition.

        Every agent node becomes a graph node that invokes its model.
        - A node with any tools also gets a ``<node_key>__tools`` ToolNode.
        - A node with a model and more than one outgoing edge gets one handoff
          tool per child, so the model chooses which child to route to.

        :param eval_requests: Per-run list that nodes with judges append
            :class:`EvalRequest` entries to while the graph executes.
        :return: Tuple of (compiled_graph, fn_name_to_config_key, node_keys) where
            fn_name_to_config_key maps tool name to LD config key, and
            node_keys is the set of all agent node keys in the graph.
        """
        from langchain_core.messages import SystemMessage
        from langgraph.graph import END, START, StateGraph
        from langgraph.graph.message import add_messages
        from langgraph.prebuilt import ToolNode, tools_condition
        from typing_extensions import TypedDict

        class WorkflowState(TypedDict):
            messages: Annotated[List[Any], add_messages]

        agent_builder: StateGraph = StateGraph(WorkflowState)
        root_node = self._graph.root()
        root_key = root_node.get_key() if root_node else None
        tools_ref = self._tools
        graph_structure: List[str] = []
        fn_name_to_config_key: Dict[str, str] = {}
        node_keys: Set[str] = set()

        def handle_traversal(node: AgentGraphNode, ctx: dict) -> None:
            # ``ctx`` is part of the traverse() callback signature; unused here.
            node_config = node.get_config()
            node_key = node.get_key()
            node_keys.add(node_key)
            instructions = getattr(node_config, 'instructions', None)
            outgoing_edges = node.get_edges()

            lc_model = None
            tool_fns: list = []
            if node_config.model:
                # Build the bare chat model; tools (functional + handoff) are
                # bound further below via bind_tools.
                lc_model = create_langchain_model(node_config)
                tool_fns = build_structured_tools(node_config, tools_ref)
                # Map tool name -> LD config key for callback attribution.
                # build_structured_tools returns StructuredTool instances with tool.name set
                # to the LD config key, so tool.name IS the config key.
                for tool in tool_fns:
                    tool_name = getattr(tool, 'name', None)
                    if tool_name:
                        fn_name_to_config_key[tool_name] = tool_name

            # For nodes with multiple children, create a handoff tool per child so the
            # LLM decides which agent to route to. Uses Command(goto=child_key) so
            # LangGraph routes to the target without looping back here.
            handoff_fns: list = []
            if lc_model and len(outgoing_edges) > 1:
                for edge in outgoing_edges:
                    child_node = self._graph.get_node(edge.target_config)
                    # Read the child's instructions defensively, mirroring how
                    # this node's own instructions are read above.
                    child_instructions = (
                        getattr(child_node.get_config(), 'instructions', None)
                        if child_node
                        else None
                    )
                    description = (
                        (edge.handoff or {}).get('description')
                        or (child_instructions[:120] if child_instructions else None)
                        or f"Transfer control to {edge.target_config}"
                    )
                    handoff_fns.append(_make_handoff_tool(edge.target_config, description))

            all_tools = tool_fns + handoff_fns
            model: Any
            if lc_model and all_tools:
                # When handoff tools are present, disable parallel tool calls so the LLM
                # picks exactly one destination rather than routing to multiple children.
                bind_kwargs: Dict[str, Any] = {'parallel_tool_calls': False} if handoff_fns else {}
                model = lc_model.bind_tools(all_tools, **bind_kwargs)
            else:
                model = lc_model

            # Names of the handoff tools attached to this node. Tool calls
            # against these are control-flow signals, not the agent doing work,
            # so they must not block emission of an EvalRequest.
            handoff_tool_names: FrozenSet[str] = frozenset(
                getattr(t, 'name', '') for t in handoff_fns
            )

            # Whether this node has at least one judge configured. Nodes without
            # judges contribute zero EvalRequest entries.
            jc = getattr(node_config, 'judge_configuration', None)
            node_has_judges = bool(jc is not None and getattr(jc, 'judges', None))

            def make_node_fn(
                bound_model: Any,
                node_instructions: Any,
                nk: str,
                ht_names: FrozenSet[str],
                emit_eval: bool,
            ):
                # Factory so each node's coroutine captures its own values
                # rather than the loop's late-bound variables.
                async def invoke(state: WorkflowState) -> dict:
                    if not bound_model:
                        return {'messages': []}
                    msgs = list(state['messages'])
                    if node_instructions:
                        msgs = [SystemMessage(content=node_instructions)] + msgs
                    response = await bound_model.ainvoke(msgs)
                    if emit_eval:
                        _maybe_record_eval_request(
                            eval_requests, nk, msgs, response, ht_names
                        )
                    return {'messages': [response]}

                invoke.__name__ = nk
                return invoke

            invoke_fn = make_node_fn(
                model, instructions, node_key, handoff_tool_names, node_has_judges
            )
            agent_builder.add_node(node_key, invoke_fn)
            if node_key == root_key:
                agent_builder.add_edge(START, node_key)

            # Collect node info for graph structure log
            tool_names = [str(getattr(t, 'name', None) or getattr(t, '__name__', t)) for t in tool_fns]
            edge_targets = [e.target_config for e in outgoing_edges]
            node_desc = node_key
            if tool_names:
                node_desc += f"[tools:{','.join(tool_names)}]"
            if handoff_fns:
                node_desc += f"[handoff:{','.join(edge_targets)}]"
            elif edge_targets:
                node_desc += f"→{','.join(edge_targets)}"
            else:
                node_desc += "(terminal)"
            graph_structure.append(node_desc)

            if all_tools:
                tools_node_key = f"{node_key}__tools"
                agent_builder.add_node(tools_node_key, ToolNode(all_tools))
                if not handoff_fns:
                    # No handoff tools: standard loop-back after tool execution.
                    after_loop = outgoing_edges[0].target_config if outgoing_edges else END
                    if len(outgoing_edges) > 1:
                        log.warning(
                            f"Node '{node_key}' has {len(outgoing_edges)} outgoing edges but no handoff "
                            "tools; only the first edge will be used after the tool loop. "
                            "Use handoff tools for multi-child routing."
                        )
                    agent_builder.add_edge(tools_node_key, node_key)
                    agent_builder.add_conditional_edges(
                        node_key,
                        tools_condition,
                        {"tools": tools_node_key, END: after_loop},
                    )
                elif not tool_fns:
                    # Only handoff tools: no loop-back needed.
                    # Command(goto=child_key) handles routing to the target.
                    agent_builder.add_conditional_edges(
                        node_key,
                        tools_condition,
                        {"tools": tools_node_key, END: END},
                    )
                else:
                    # Both functional and handoff tools. A static loop-back edge would
                    # fan-out with Command(goto=child_key) from handoff tools, so use a
                    # conditional edge that only loops back for functional tool results.
                    def make_after_tools_router(parent_key: str, ht_names: FrozenSet[str]):
                        def route(state: WorkflowState) -> str:
                            msgs = state['messages']
                            # A handoff ToolMessage is last: Command(goto=...)
                            # already routes, so stop this branch.
                            if msgs and getattr(msgs[-1], 'name', None) in ht_names:
                                return END
                            return parent_key

                        return route

                    agent_builder.add_conditional_edges(
                        tools_node_key,
                        make_after_tools_router(node_key, handoff_tool_names),
                        {node_key: node_key, END: END},
                    )
                    agent_builder.add_conditional_edges(
                        node_key,
                        tools_condition,
                        {"tools": tools_node_key, END: END},
                    )
            else:
                if node.is_terminal():
                    agent_builder.add_edge(node_key, END)
                for edge in outgoing_edges:
                    agent_builder.add_edge(node_key, edge.target_config)
            return None

        self._graph.traverse(fn=handle_traversal)

        # Read the graph key defensively: it is only used for this debug log
        # and must not be able to fail the whole run.
        agent_graph = getattr(self._graph, '_agent_graph', None)
        graph_key_str = getattr(agent_graph, 'key', None) or 'unknown'
        log.debug(
            f"LangGraphAgentGraphRunner: graph='{graph_key_str}', root='{root_key}', "
            f"structure: {' | '.join(graph_structure)}"
        )
        compiled = agent_builder.compile()
        return compiled, fn_name_to_config_key, node_keys

    async def run(self, input: str) -> AgentGraphRunnerResult:
        """
        Run the agent graph with the given input.

        Builds a LangGraph StateGraph from the AgentGraphDefinition, compiles
        it, and invokes it. Uses a LangChain callback handler to collect
        per-node metrics. Graph-level tracking events are emitted by the
        managed layer from the returned AIGraphMetrics.

        Failures (including a missing ``langgraph`` install) are logged and
        reported as a result with ``success=False`` rather than raised.

        :param input: The string prompt to send to the agent graph
        :return: AgentGraphRunnerResult with the final content and AIGraphMetrics
        """
        start_ns = time.perf_counter_ns()
        # Per-run state — kept local so concurrent run() calls do not share it.
        eval_requests: List[EvalRequest] = []
        try:
            from langchain_core.messages import HumanMessage

            compiled, fn_name_to_config_key, node_keys = self._build_graph(eval_requests)
            handler = LDMetricsCallbackHandler(node_keys, fn_name_to_config_key)
            result = await compiled.ainvoke(  # type: ignore[call-overload]
                {'messages': [HumanMessage(content=input)]},
                config={
                    'callbacks': [handler],
                    'recursion_limit': self._recursion_limit,
                },
            )
            duration_ms = (time.perf_counter_ns() - start_ns) // 1_000_000
            messages = result.get('messages', [])
            output = extract_last_message_content(messages)
            total_usage = sum_token_usage_from_messages(messages)
            node_metrics = handler.node_metrics
            return AgentGraphRunnerResult(
                content=output,
                raw=result,
                metrics=AIGraphMetrics(
                    success=True,
                    path=handler.path,
                    duration_ms=duration_ms,
                    tokens=total_usage if (total_usage is not None and total_usage.total > 0) else None,
                    node_metrics=node_metrics,
                ),
                eval_requests=eval_requests if eval_requests else None,
            )
        except Exception as exc:
            if isinstance(exc, ImportError):
                log.warning(
                    "langgraph is required for LangGraphAgentGraphRunner. "
                    "Install it with: pip install langgraph"
                )
            else:
                log.warning(f'LangGraphAgentGraphRunner run failed: {exc}')
            duration_ms = (time.perf_counter_ns() - start_ns) // 1_000_000
            return AgentGraphRunnerResult(
                content='',
                raw=None,
                metrics=AIGraphMetrics(
                    success=False,
                    duration_ms=duration_ms,
                ),
                # Requests recorded before the failure are still returned.
                eval_requests=eval_requests if eval_requests else None,
            )