Skip to content

Commit efa8e00

Browse files
authored
feat: Migrate LangGraph runner to AgentGraphRunnerResult; clean up legacy shape detection (#156)
1 parent 388b7af commit efa8e00

6 files changed

Lines changed: 269 additions & 493 deletions

File tree

packages/ai-providers/server-ai-langchain/src/ldai_langchain/langgraph_agent_graph_runner.py

Lines changed: 29 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,12 @@
11
"""LangGraph agent graph runner for LaunchDarkly AI SDK."""
22

3-
import asyncio
43
import time
5-
from contextvars import ContextVar
64
from typing import Annotated, Any, Dict, List, Set, Tuple
75

86
from ldai import log
97
from ldai.agent_graph import AgentGraphDefinition, AgentGraphNode
10-
from ldai.providers import AgentGraphResult, AgentGraphRunner, ToolRegistry
11-
from ldai.providers.types import LDAIMetrics
8+
from ldai.providers import AgentGraphRunner, ToolRegistry
9+
from ldai.providers.types import AgentGraphRunnerResult, GraphMetrics
1210

1311
from ldai_langchain.langchain_helper import (
1412
build_structured_tools,
@@ -18,9 +16,6 @@
1816
)
1917
from ldai_langchain.langgraph_callback_handler import LDMetricsCallbackHandler
2018

21-
# Per-run eval task accumulator, isolated per concurrent run() call via ContextVar.
22-
_run_eval_tasks: ContextVar[Dict[str, List[asyncio.Task]]] = ContextVar('_run_eval_tasks')
23-
2419

2520
def _make_handoff_tool(child_key: str, description: str) -> Any:
2621
"""
@@ -65,9 +60,10 @@ class LangGraphAgentGraphRunner(AgentGraphRunner):
6560
6661
AgentGraphRunner implementation for LangGraph.
6762
68-
Compiles and runs the agent graph with LangGraph and automatically records
69-
graph- and node-level AI metric data to the LaunchDarkly trackers on the
70-
graph definition and each node.
63+
Compiles and runs the agent graph with LangGraph and collects graph- and
64+
node-level metrics via a LangChain callback handler. Tracking events are
65+
emitted by the managed layer (:class:`~ldai.ManagedAgentGraph`) from the
66+
returned :class:`~ldai.providers.types.AgentGraphRunnerResult`.
7167
7268
Requires ``langgraph`` to be installed.
7369
"""
@@ -181,26 +177,6 @@ async def invoke(state: WorkflowState) -> dict:
181177
if node_instructions:
182178
msgs = [SystemMessage(content=node_instructions)] + msgs
183179
response = await bound_model.ainvoke(msgs)
184-
185-
node_obj = self._graph.get_node(nk)
186-
if node_obj is not None:
187-
input_text = '\r\n'.join(
188-
m.content if isinstance(m.content, str) else str(m.content)
189-
for m in msgs
190-
) if msgs else ''
191-
output_text = (
192-
response.content if hasattr(response, 'content') else str(response)
193-
)
194-
task = node_obj.get_config().evaluator.evaluate(input_text, output_text)
195-
run_tasks = _run_eval_tasks.get(None)
196-
if run_tasks is not None:
197-
run_tasks.setdefault(nk, []).append(task)
198-
else:
199-
log.warning(
200-
f"LangGraphAgentGraphRunner: eval task for node '{nk}' "
201-
"has no run context; judge results will not be tracked"
202-
)
203-
204180
return {'messages': [response]}
205181

206182
invoke.__name__ = nk
@@ -298,20 +274,18 @@ def route(state: WorkflowState) -> str:
298274
compiled = agent_builder.compile()
299275
return compiled, fn_name_to_config_key, node_keys
300276

301-
async def run(self, input: Any) -> AgentGraphResult:
277+
async def run(self, input: Any) -> AgentGraphRunnerResult:
302278
"""
303279
Run the agent graph with the given input.
304280
305281
Builds a LangGraph StateGraph from the AgentGraphDefinition, compiles
306282
it, and invokes it. Uses a LangChain callback handler to collect
307-
per-node metrics, then flushes them to LaunchDarkly trackers.
283+
per-node metrics. Graph-level tracking events are emitted by the
284+
managed layer from the returned GraphMetrics.
308285
309286
:param input: The string prompt to send to the agent graph
310-
:return: AgentGraphResult with the final output and metrics
287+
:return: AgentGraphRunnerResult with the final content and GraphMetrics
311288
"""
312-
pending_eval_tasks: Dict[str, List[asyncio.Task]] = {}
313-
token = _run_eval_tasks.set(pending_eval_tasks)
314-
tracker = self._graph.create_tracker()
315289
start_ns = time.perf_counter_ns()
316290

317291
try:
@@ -325,24 +299,23 @@ async def run(self, input: Any) -> AgentGraphResult:
325299
config={'callbacks': [handler], 'recursion_limit': 25},
326300
)
327301

328-
duration = (time.perf_counter_ns() - start_ns) // 1_000_000
302+
duration_ms = (time.perf_counter_ns() - start_ns) // 1_000_000
329303
messages = result.get('messages', [])
330304
output = extract_last_message_content(messages)
305+
total_usage = sum_token_usage_from_messages(messages)
331306

332-
# Flush per-node metrics to LD trackers; eval results are tracked
333-
# internally and intentionally not exposed on AgentGraphResult here
334-
# — judge dispatch is the managed layer's responsibility.
335-
await handler.flush(self._graph, pending_eval_tasks)
336-
337-
tracker.track_path(handler.path)
338-
tracker.track_duration(duration)
339-
tracker.track_invocation_success()
340-
tracker.track_total_tokens(sum_token_usage_from_messages(messages))
307+
node_metrics = handler.node_metrics
341308

342-
return AgentGraphResult(
343-
output=output,
309+
return AgentGraphRunnerResult(
310+
content=output,
344311
raw=result,
345-
metrics=LDAIMetrics(success=True),
312+
metrics=GraphMetrics(
313+
success=True,
314+
path=handler.path,
315+
duration_ms=duration_ms,
316+
usage=total_usage if (total_usage is not None and total_usage.total > 0) else None,
317+
node_metrics=node_metrics,
318+
),
346319
)
347320

348321
except Exception as exc:
@@ -353,13 +326,12 @@ async def run(self, input: Any) -> AgentGraphResult:
353326
)
354327
else:
355328
log.warning(f'LangGraphAgentGraphRunner run failed: {exc}')
356-
duration = (time.perf_counter_ns() - start_ns) // 1_000_000
357-
tracker.track_duration(duration)
358-
tracker.track_invocation_failure()
359-
return AgentGraphResult(
360-
output='',
329+
duration_ms = (time.perf_counter_ns() - start_ns) // 1_000_000
330+
return AgentGraphRunnerResult(
331+
content='',
361332
raw=None,
362-
metrics=LDAIMetrics(success=False),
333+
metrics=GraphMetrics(
334+
success=False,
335+
duration_ms=duration_ms,
336+
),
363337
)
364-
finally:
365-
_run_eval_tasks.reset(token)

packages/ai-providers/server-ai-langchain/src/ldai_langchain/langgraph_callback_handler.py

Lines changed: 28 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@
44

55
from langchain_core.callbacks import BaseCallbackHandler
66
from langchain_core.outputs import ChatGeneration, LLMResult
7-
from ldai.agent_graph import AgentGraphDefinition
8-
from ldai.providers.types import JudgeResult
7+
from ldai.providers.types import LDAIMetrics
98
from ldai.tracker import TokenUsage
109

1110
from ldai_langchain.langchain_helper import get_ai_usage_from_response
@@ -20,8 +19,10 @@ class LDMetricsCallbackHandler(BaseCallbackHandler):
2019
2120
LangChain callback handler that collects per-node metrics during a LangGraph run.
2221
23-
Records token usage, tool calls, and duration for each agent node in the graph,
24-
then flushes them to LaunchDarkly trackers after the run completes via ``flush()``.
22+
Records token usage, tool calls, and duration for each agent node in the graph.
23+
Each node's :class:`~ldai.providers.types.LDAIMetrics` is built incrementally
24+
as callbacks fire. Access the ``node_metrics`` property after the run completes
25+
to retrieve the accumulated per-node metrics.
2526
"""
2627

2728
def __init__(self, node_keys: Set[str], fn_name_to_config_key: Dict[str, str]):
@@ -39,14 +40,10 @@ def __init__(self, node_keys: Set[str], fn_name_to_config_key: Dict[str, str]):
3940

4041
# run_id -> node_key for active chain runs
4142
self._run_to_node: Dict[UUID, str] = {}
42-
# accumulated token usage per node
43-
self._node_tokens: Dict[str, TokenUsage] = {}
44-
# tool config keys called per node
45-
self._node_tool_calls: Dict[str, List[str]] = {}
4643
# start time (ns) per active run_id — keyed by run_id to handle re-entrant nodes
4744
self._node_start_ns: Dict[UUID, int] = {}
48-
# accumulated duration (ms) per node
49-
self._node_duration_ms: Dict[str, int] = {}
45+
# per-node metrics, built incrementally as callbacks fire
46+
self._node_metrics: Dict[str, LDAIMetrics] = {}
5047
# execution path in order (deduplicated)
5148
self._path: List[str] = []
5249
self._path_set: Set[str] = set()
@@ -61,19 +58,9 @@ def path(self) -> List[str]:
6158
return list(self._path)
6259

6360
@property
64-
def node_tokens(self) -> Dict[str, TokenUsage]:
65-
"""Accumulated token usage per node key."""
66-
return dict(self._node_tokens)
67-
68-
@property
69-
def node_tool_calls(self) -> Dict[str, List[str]]:
70-
"""Tool config keys called per node key."""
71-
return {k: list(v) for k, v in self._node_tool_calls.items()}
72-
73-
@property
74-
def node_durations_ms(self) -> Dict[str, int]:
75-
"""Accumulated duration in milliseconds per node key."""
76-
return dict(self._node_duration_ms)
61+
def node_metrics(self) -> Dict[str, LDAIMetrics]:
62+
"""Per-node metrics keyed by node key."""
63+
return dict(self._node_metrics)
7764

7865
# ------------------------------------------------------------------
7966
# Callbacks
@@ -101,10 +88,10 @@ def on_chain_start(
10188
if name not in self._path_set:
10289
self._path.append(name)
10390
self._path_set.add(name)
91+
self._node_metrics[name] = LDAIMetrics(success=False)
10492
elif name.endswith('__tools'):
10593
stripped = name[: -len('__tools')]
10694
if stripped in self._node_keys:
107-
# Attribute tool events to the owning agent node
10895
self._run_to_node[run_id] = stripped
10996

11097
def on_chain_end(
@@ -121,9 +108,10 @@ def on_chain_end(
121108
start_ns = self._node_start_ns.pop(run_id, None)
122109
if start_ns is not None:
123110
elapsed_ms = (time.perf_counter_ns() - start_ns) // 1_000_000
124-
self._node_duration_ms[node_key] = (
125-
self._node_duration_ms.get(node_key, 0) + elapsed_ms
126-
)
111+
metrics = self._node_metrics.get(node_key)
112+
if metrics is not None:
113+
metrics.success = True
114+
metrics.duration_ms = (metrics.duration_ms or 0) + elapsed_ms
127115

128116
def on_llm_end(
129117
self,
@@ -151,11 +139,14 @@ def on_llm_end(
151139
if usage is None:
152140
return
153141

154-
existing = self._node_tokens.get(node_key)
142+
metrics = self._node_metrics.get(node_key)
143+
if metrics is None:
144+
return
145+
existing = metrics.usage
155146
if existing is None:
156-
self._node_tokens[node_key] = usage
147+
metrics.usage = usage
157148
else:
158-
self._node_tokens[node_key] = TokenUsage(
149+
metrics.usage = TokenUsage(
159150
total=existing.total + usage.total,
160151
input=existing.input + usage.input,
161152
output=existing.output + usage.output,
@@ -179,64 +170,11 @@ def on_tool_end(
179170

180171
config_key = self._fn_name_to_config_key.get(name)
181172
if config_key is None:
182-
# Tool is not a registered functional tool (e.g. a handoff tool) — skip tracking.
183173
return
184-
if node_key not in self._node_tool_calls:
185-
self._node_tool_calls[node_key] = []
186-
self._node_tool_calls[node_key].append(config_key)
187-
188-
# ------------------------------------------------------------------
189-
# Flush
190-
# ------------------------------------------------------------------
191-
192-
async def flush(
193-
self, graph: AgentGraphDefinition, eval_tasks=None
194-
) -> List[JudgeResult]:
195-
"""
196-
Emit all collected per-node metrics to the LaunchDarkly trackers.
197-
198-
Call this once after the graph run completes.
199-
200-
:param graph: The AgentGraphDefinition whose nodes hold the LD config trackers.
201-
:param eval_tasks: Optional dict mapping node key to a list of awaitables that
202-
return judge evaluation results. Multiple tasks arise when a node is visited
203-
more than once (e.g. in a graph with cycles).
204-
:return: All judge results collected across all nodes.
205-
"""
206-
node_trackers: Dict[str, Any] = {}
207-
all_eval_results: List[JudgeResult] = []
208-
for node_key in self._path:
209-
if node_key in node_trackers:
210-
continue
211-
node = graph.get_node(node_key)
212-
if not node:
213-
continue
214-
config_tracker = node.get_config().create_tracker()
215-
if not config_tracker:
216-
continue
217-
node_trackers[node_key] = config_tracker
218-
219-
usage = self._node_tokens.get(node_key)
220-
if usage:
221-
config_tracker.track_tokens(usage)
222-
223-
duration = self._node_duration_ms.get(node_key)
224-
if duration is not None:
225-
config_tracker.track_duration(duration)
226-
227-
config_tracker.track_success()
228-
229-
for tool_key in self._node_tool_calls.get(node_key, []):
230-
config_tracker.track_tool_call(tool_key)
231-
232-
if not eval_tasks:
233-
continue
234-
235-
for eval_task in eval_tasks.get(node_key, []):
236-
results = await eval_task
237-
all_eval_results.extend(results)
238-
for r in results:
239-
if r.success:
240-
config_tracker.track_judge_result(r)
241-
242-
return all_eval_results
174+
metrics = self._node_metrics.get(node_key)
175+
if metrics is None:
176+
return
177+
if metrics.tool_calls is None:
178+
metrics.tool_calls = [config_key]
179+
else:
180+
metrics.tool_calls.append(config_key)

packages/ai-providers/server-ai-langchain/tests/test_langchain_provider.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
"""Tests for LangChain Provider."""
22

3-
import pytest
43
from unittest.mock import AsyncMock, MagicMock
54

5+
import pytest
66
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
7-
87
from ldai import LDMessage
98
from ldai.evaluator import Evaluator
109

@@ -404,6 +403,7 @@ class TestCreateAgent:
404403
def test_creates_agent_runner_with_instructions_and_tool_definitions(self):
405404
"""Should create LangChainAgentRunner wrapping a compiled graph."""
406405
from unittest.mock import patch
406+
407407
from ldai_langchain import LangChainAgentRunner
408408

409409
mock_ai_config = MagicMock()
@@ -436,6 +436,7 @@ def test_creates_agent_runner_with_instructions_and_tool_definitions(self):
436436
def test_creates_agent_runner_with_no_tools(self):
437437
"""Should create LangChainAgentRunner with no tool definitions."""
438438
from unittest.mock import patch
439+
439440
from ldai_langchain import LangChainAgentRunner
440441

441442
mock_ai_config = MagicMock()
@@ -522,6 +523,7 @@ class TestBuildTools:
522523

523524
def test_registers_sync_callable_as_structured_tool_func(self):
524525
from ldai.models import AIAgentConfig, ModelConfig, ProviderConfig
526+
525527
from ldai_langchain.langchain_helper import build_structured_tools
526528

527529
def sync_tool(x: str = '') -> str:
@@ -546,6 +548,7 @@ def sync_tool(x: str = '') -> str:
546548

547549
def test_registers_async_callable_as_structured_tool_coroutine(self):
548550
from ldai.models import AIAgentConfig, ModelConfig, ProviderConfig
551+
549552
from ldai_langchain.langchain_helper import build_structured_tools
550553

551554
async def async_tool(x: str = '') -> str:

0 commit comments

Comments
 (0)