Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
"""LangGraph agent graph runner for LaunchDarkly AI SDK."""

import asyncio
import time
from contextvars import ContextVar
from typing import Annotated, Any, Dict, List, Set, Tuple

from ldai import log
from ldai.agent_graph import AgentGraphDefinition, AgentGraphNode
from ldai.providers import AgentGraphResult, AgentGraphRunner, ToolRegistry
from ldai.providers.types import LDAIMetrics
from ldai.providers import AgentGraphRunner, ToolRegistry
from ldai.providers.types import AgentGraphRunnerResult, GraphMetrics, LDAIMetrics

from ldai_langchain.langchain_helper import (
build_structured_tools,
Expand All @@ -18,9 +16,6 @@
)
from ldai_langchain.langgraph_callback_handler import LDMetricsCallbackHandler

# Per-run eval task accumulator, isolated per concurrent run() call via ContextVar.
_run_eval_tasks: ContextVar[Dict[str, List[asyncio.Task]]] = ContextVar('_run_eval_tasks')


def _make_handoff_tool(child_key: str, description: str) -> Any:
"""
Expand Down Expand Up @@ -65,9 +60,10 @@ class LangGraphAgentGraphRunner(AgentGraphRunner):

AgentGraphRunner implementation for LangGraph.

Compiles and runs the agent graph with LangGraph and automatically records
graph- and node-level AI metric data to the LaunchDarkly trackers on the
graph definition and each node.
Compiles and runs the agent graph with LangGraph and collects graph- and
node-level metrics via a LangChain callback handler. Tracking events are
emitted by the managed layer (:class:`~ldai.ManagedAgentGraph`) from the
returned :class:`~ldai.providers.types.AgentGraphRunnerResult`.

Requires ``langgraph`` to be installed.
"""
Expand Down Expand Up @@ -181,26 +177,6 @@ async def invoke(state: WorkflowState) -> dict:
if node_instructions:
msgs = [SystemMessage(content=node_instructions)] + msgs
response = await bound_model.ainvoke(msgs)

node_obj = self._graph.get_node(nk)
if node_obj is not None:
input_text = '\r\n'.join(
m.content if isinstance(m.content, str) else str(m.content)
for m in msgs
) if msgs else ''
output_text = (
response.content if hasattr(response, 'content') else str(response)
)
task = node_obj.get_config().evaluator.evaluate(input_text, output_text)
run_tasks = _run_eval_tasks.get(None)
if run_tasks is not None:
run_tasks.setdefault(nk, []).append(task)
else:
log.warning(
f"LangGraphAgentGraphRunner: eval task for node '{nk}' "
"has no run context; judge results will not be tracked"
)

return {'messages': [response]}

invoke.__name__ = nk
Expand Down Expand Up @@ -298,20 +274,18 @@ def route(state: WorkflowState) -> str:
compiled = agent_builder.compile()
return compiled, fn_name_to_config_key, node_keys

async def run(self, input: Any) -> AgentGraphResult:
async def run(self, input: Any) -> AgentGraphRunnerResult:
"""
Run the agent graph with the given input.

Builds a LangGraph StateGraph from the AgentGraphDefinition, compiles
it, and invokes it. Uses a LangChain callback handler to collect
per-node metrics, then flushes them to LaunchDarkly trackers.
per-node metrics. Graph-level tracking events are emitted by the
managed layer from the returned GraphMetrics.

:param input: The string prompt to send to the agent graph
:return: AgentGraphResult with the final output and metrics
:return: AgentGraphRunnerResult with the final content and GraphMetrics
"""
pending_eval_tasks: Dict[str, List[asyncio.Task]] = {}
token = _run_eval_tasks.set(pending_eval_tasks)
tracker = self._graph.create_tracker()
start_ns = time.perf_counter_ns()

try:
Expand All @@ -325,24 +299,34 @@ async def run(self, input: Any) -> AgentGraphResult:
config={'callbacks': [handler], 'recursion_limit': 25},
)

duration = (time.perf_counter_ns() - start_ns) // 1_000_000
duration_ms = (time.perf_counter_ns() - start_ns) // 1_000_000
messages = result.get('messages', [])
output = extract_last_message_content(messages)
total_usage = sum_token_usage_from_messages(messages)

# Build per-node LDAIMetrics from callback handler data
node_metrics: Dict[str, LDAIMetrics] = {}
for node_key in handler.path:
usage = handler.node_tokens.get(node_key)
duration = handler.node_durations_ms.get(node_key)
tool_calls = handler.node_tool_calls.get(node_key) or []
node_metrics[node_key] = LDAIMetrics(
success=True,
usage=usage,
duration_ms=duration,
tool_calls=tool_calls if tool_calls else None,
)
Comment thread
jsonbailey marked this conversation as resolved.
Outdated

# Flush per-node metrics to LD trackers; eval results are tracked
# internally and intentionally not exposed on AgentGraphResult here
# — judge dispatch is the managed layer's responsibility.
await handler.flush(self._graph, pending_eval_tasks)

tracker.track_path(handler.path)
tracker.track_duration(duration)
tracker.track_invocation_success()
tracker.track_total_tokens(sum_token_usage_from_messages(messages))

return AgentGraphResult(
output=output,
return AgentGraphRunnerResult(
content=output,
raw=result,
metrics=LDAIMetrics(success=True),
metrics=GraphMetrics(
success=True,
path=handler.path,
duration_ms=duration_ms,
usage=total_usage if (total_usage is not None and total_usage.total > 0) else None,
node_metrics=node_metrics,
),
)

except Exception as exc:
Expand All @@ -353,13 +337,12 @@ async def run(self, input: Any) -> AgentGraphResult:
)
else:
log.warning(f'LangGraphAgentGraphRunner run failed: {exc}')
duration = (time.perf_counter_ns() - start_ns) // 1_000_000
tracker.track_duration(duration)
tracker.track_invocation_failure()
return AgentGraphResult(
output='',
duration_ms = (time.perf_counter_ns() - start_ns) // 1_000_000
return AgentGraphRunnerResult(
content='',
raw=None,
metrics=LDAIMetrics(success=False),
metrics=GraphMetrics(
success=False,
duration_ms=duration_ms,
),
)
finally:
_run_eval_tasks.reset(token)
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@

from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import ChatGeneration, LLMResult
from ldai.agent_graph import AgentGraphDefinition
from ldai.providers.types import JudgeResult
from ldai.providers.types import LDAIMetrics
from ldai.tracker import TokenUsage

from ldai_langchain.langchain_helper import get_ai_usage_from_response
Expand All @@ -20,8 +19,9 @@ class LDMetricsCallbackHandler(BaseCallbackHandler):

LangChain callback handler that collects per-node metrics during a LangGraph run.

Records token usage, tool calls, and duration for each agent node in the graph,
then flushes them to LaunchDarkly trackers after the run completes via ``flush()``.
Records token usage, tool calls, and duration for each agent node in the graph.
Call ``collect_node_metrics()`` after the run completes to retrieve the accumulated
per-node metrics for use by the managed layer.
"""

def __init__(self, node_keys: Set[str], fn_name_to_config_key: Dict[str, str]):
Expand Down Expand Up @@ -185,58 +185,26 @@ def on_tool_end(
self._node_tool_calls[node_key] = []
self._node_tool_calls[node_key].append(config_key)

# ------------------------------------------------------------------
# Flush
# ------------------------------------------------------------------

async def flush(
self, graph: AgentGraphDefinition, eval_tasks=None
) -> List[JudgeResult]:
def collect_node_metrics(self) -> Dict[str, LDAIMetrics]:
"""
Emit all collected per-node metrics to the LaunchDarkly trackers.
Build a per-node ``LDAIMetrics`` map from data collected during the run.

Call this once after the graph run completes.
Pure data extraction — no LaunchDarkly tracker events are emitted.
:class:`LangGraphAgentGraphRunner` uses this to populate
``GraphMetrics.node_metrics`` so the managed layer can drive per-node
events.

:param graph: The AgentGraphDefinition whose nodes hold the LD config trackers.
:param eval_tasks: Optional dict mapping node key to a list of awaitables that
return judge evaluation results. Multiple tasks arise when a node is visited
more than once (e.g. in a graph with cycles).
:return: All judge results collected across all nodes.
:return: Mapping of node key to its accumulated ``LDAIMetrics``.
"""
node_trackers: Dict[str, Any] = {}
all_eval_results: List[JudgeResult] = []
node_metrics: Dict[str, LDAIMetrics] = {}
for node_key in self._path:
if node_key in node_trackers:
continue
node = graph.get_node(node_key)
if not node:
continue
config_tracker = node.get_config().create_tracker()
if not config_tracker:
continue
node_trackers[node_key] = config_tracker

usage = self._node_tokens.get(node_key)
if usage:
config_tracker.track_tokens(usage)

duration = self._node_duration_ms.get(node_key)
if duration is not None:
config_tracker.track_duration(duration)

config_tracker.track_success()

for tool_key in self._node_tool_calls.get(node_key, []):
config_tracker.track_tool_call(tool_key)

if not eval_tasks:
if node_key in node_metrics:
continue

for eval_task in eval_tasks.get(node_key, []):
results = await eval_task
all_eval_results.extend(results)
for r in results:
if r.success:
config_tracker.track_judge_result(r)

return all_eval_results
tool_calls = self._node_tool_calls.get(node_key, [])
node_metrics[node_key] = LDAIMetrics(
success=True,
usage=self._node_tokens.get(node_key),
tool_calls=list(tool_calls) if tool_calls else None,
duration_ms=self._node_duration_ms.get(node_key),
)
return node_metrics
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
"""Tests for LangChain Provider."""

import pytest
from unittest.mock import AsyncMock, MagicMock

import pytest
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage

from ldai import LDMessage
from ldai.evaluator import Evaluator

Expand Down Expand Up @@ -404,6 +403,7 @@ class TestCreateAgent:
def test_creates_agent_runner_with_instructions_and_tool_definitions(self):
"""Should create LangChainAgentRunner wrapping a compiled graph."""
from unittest.mock import patch

from ldai_langchain import LangChainAgentRunner

mock_ai_config = MagicMock()
Expand Down Expand Up @@ -436,6 +436,7 @@ def test_creates_agent_runner_with_instructions_and_tool_definitions(self):
def test_creates_agent_runner_with_no_tools(self):
"""Should create LangChainAgentRunner with no tool definitions."""
from unittest.mock import patch

from ldai_langchain import LangChainAgentRunner

mock_ai_config = MagicMock()
Expand Down Expand Up @@ -522,6 +523,7 @@ class TestBuildTools:

def test_registers_sync_callable_as_structured_tool_func(self):
from ldai.models import AIAgentConfig, ModelConfig, ProviderConfig

from ldai_langchain.langchain_helper import build_structured_tools

def sync_tool(x: str = '') -> str:
Expand All @@ -546,6 +548,7 @@ def sync_tool(x: str = '') -> str:

def test_registers_async_callable_as_structured_tool_coroutine(self):
from ldai.models import AIAgentConfig, ModelConfig, ProviderConfig

from ldai_langchain.langchain_helper import build_structured_tools

async def async_tool(x: str = '') -> str:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
"""Tests for LangGraphAgentGraphRunner and LangChainRunnerFactory.create_agent_graph()."""

import pytest
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from ldai.agent_graph import AgentGraphDefinition
from ldai.evaluator import Evaluator
from ldai.models import AIAgentGraphConfig, AIAgentConfig, ModelConfig, ProviderConfig
from ldai.providers import AgentGraphResult, ToolRegistry
from ldai_langchain.langgraph_agent_graph_runner import LangGraphAgentGraphRunner
from ldai.models import AIAgentConfig, AIAgentGraphConfig, ModelConfig, ProviderConfig
from ldai.providers import ToolRegistry
from ldai.providers.types import AgentGraphRunnerResult

from ldai_langchain.langchain_runner_factory import LangChainRunnerFactory
from ldai_langchain.langgraph_agent_graph_runner import LangGraphAgentGraphRunner


def _make_graph(enabled: bool = True) -> AgentGraphDefinition:
Expand Down Expand Up @@ -75,22 +77,22 @@ async def test_langgraph_runner_run_raises_when_langgraph_not_installed():

with patch.dict('sys.modules', {'langgraph': None, 'langgraph.graph': None}):
result = await runner.run("test")
assert isinstance(result, AgentGraphResult)
assert isinstance(result, AgentGraphRunnerResult)
assert result.metrics.success is False


@pytest.mark.asyncio
async def test_langgraph_runner_run_tracks_failure_on_exception():
async def test_langgraph_runner_run_returns_failure_on_exception():
"""Runner now returns AgentGraphRunnerResult; managed layer drives tracker events."""
graph = _make_graph()
tracker = graph.create_tracker()
runner = LangGraphAgentGraphRunner(graph, {})

with patch.dict('sys.modules', {'langgraph': None, 'langgraph.graph': None}):
result = await runner.run("fail")

assert isinstance(result, AgentGraphRunnerResult)
assert result.metrics.success is False
tracker.track_invocation_failure.assert_called_once()
tracker.track_duration.assert_called_once()
assert result.metrics.duration_ms is not None


@pytest.mark.asyncio
Expand Down Expand Up @@ -147,9 +149,10 @@ async def test_langgraph_runner_run_success():
runner = LangGraphAgentGraphRunner(graph, {})
result = await runner.run("find restaurants")

assert isinstance(result, AgentGraphResult)
assert result.output == "langgraph answer"
assert result.metrics.success is True
tracker.track_path.assert_called_once_with([])
tracker.track_invocation_success.assert_called_once()
tracker.track_duration.assert_called_once()
assert isinstance(result, AgentGraphRunnerResult)
assert result.metrics.duration_ms is not None
# Tracker events now fire from the managed layer (ManagedAgentGraph) using
# result.metrics; the runner no longer touches the graph tracker directly.
tracker.track_path.assert_not_called()
tracker.track_invocation_success.assert_not_called()
tracker.track_duration.assert_not_called()
Loading
Loading