Skip to content

Commit 56ce0fd

Browse files
authored
feat: Add ManagedAgentGraph support (#111)
feat: Add OpenAIAgentGraphRunner feat: Add LangGraphAgentGraphRunner
1 parent dc592c5 commit 56ce0fd

12 files changed

Lines changed: 1139 additions & 3 deletions

File tree

packages/ai-providers/server-ai-langchain/src/ldai_langchain/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,14 @@
99
)
1010
from ldai_langchain.langchain_model_runner import LangChainModelRunner
1111
from ldai_langchain.langchain_runner_factory import LangChainRunnerFactory
12+
from ldai_langchain.langgraph_agent_graph_runner import LangGraphAgentGraphRunner
1213

1314
__version__ = "0.1.0"
1415

1516
__all__ = [
1617
'__version__',
1718
'LangChainRunnerFactory',
19+
'LangGraphAgentGraphRunner',
1820
'LangChainModelRunner',
1921
'convert_messages_to_langchain',
2022
'create_langchain_model',

packages/ai-providers/server-ai-langchain/src/ldai_langchain/langchain_runner_factory.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
from typing import Any
2+
13
from ldai.models import AIConfigKind
2-
from ldai.providers import AIProvider
4+
from ldai.providers import AIProvider, ToolRegistry
35

46
from ldai_langchain.langchain_helper import create_langchain_model
57
from ldai_langchain.langchain_model_runner import LangChainModelRunner
@@ -8,6 +10,19 @@
810
class LangChainRunnerFactory(AIProvider):
911
"""LangChain ``AIProvider`` implementation for the LaunchDarkly AI SDK."""
1012

13+
def create_agent_graph(self, graph_def: Any, tools: ToolRegistry) -> Any:
14+
"""
15+
Create a configured LangGraphAgentGraphRunner for the given graph definition.
16+
17+
:param graph_def: The AgentGraphDefinition to execute
18+
:param tools: Registry mapping tool names to callables (langchain-compatible)
19+
:return: LangGraphAgentGraphRunner ready to execute the graph
20+
"""
21+
from ldai_langchain.langgraph_agent_graph_runner import (
22+
LangGraphAgentGraphRunner,
23+
)
24+
return LangGraphAgentGraphRunner(graph_def, tools)
25+
1126
def create_model(self, config: AIConfigKind) -> LangChainModelRunner:
1227
"""
1328
Create a configured LangChainModelRunner for the given AI config.
Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
"""LangGraph agent graph runner for LaunchDarkly AI SDK."""
2+
3+
import operator
4+
import time
5+
from typing import Annotated, Any, List
6+
7+
from ldai import log
8+
from ldai.agent_graph import AgentGraphDefinition, AgentGraphNode
9+
from ldai.providers import AgentGraphResult, AgentGraphRunner, ToolRegistry
10+
from ldai.providers.types import LDAIMetrics
11+
12+
from ldai_langchain.langchain_helper import (
13+
create_langchain_model,
14+
get_ai_metrics_from_response,
15+
get_ai_usage_from_response,
16+
get_tool_calls_from_response,
17+
sum_token_usage_from_messages,
18+
)
19+
20+
21+
class LangGraphAgentGraphRunner(AgentGraphRunner):
22+
"""
23+
AgentGraphRunner implementation for LangGraph.
24+
25+
Compiles and runs the agent graph with LangGraph and automatically records
26+
graph- and node-level AI metric data to the LaunchDarkly trackers on the
27+
graph definition and each node.
28+
29+
Requires ``langgraph`` to be installed.
30+
"""
31+
32+
def __init__(self, graph: AgentGraphDefinition, tools: ToolRegistry):
33+
"""
34+
Initialize the runner.
35+
36+
:param graph: The AgentGraphDefinition to execute
37+
:param tools: Registry mapping tool names to callables (langchain-compatible)
38+
"""
39+
self._graph = graph
40+
self._tools = tools
41+
42+
async def run(self, input: Any) -> AgentGraphResult:
43+
"""
44+
Run the agent graph with the given input.
45+
46+
Builds a LangGraph StateGraph from the AgentGraphDefinition, compiles
47+
it, and invokes it. Tracks latency and invocation success/failure.
48+
49+
:param input: The string prompt to send to the agent graph
50+
:return: AgentGraphResult with the final output and metrics
51+
"""
52+
tracker = self._graph.get_tracker()
53+
start_ns = time.perf_counter_ns()
54+
try:
55+
from langchain_core.messages import AnyMessage, HumanMessage
56+
from langgraph.graph import END, START, StateGraph
57+
from typing_extensions import TypedDict
58+
59+
class WorkflowState(TypedDict):
60+
messages: Annotated[List[Any], operator.add]
61+
62+
agent_builder: StateGraph = StateGraph(WorkflowState)
63+
root_node = self._graph.root()
64+
root_key = root_node.get_key() if root_node else None
65+
tools_ref = self._tools
66+
exec_path: List[str] = []
67+
68+
def handle_traversal(node: AgentGraphNode, ctx: dict) -> None:
69+
node_config = node.get_config()
70+
node_key = node.get_key()
71+
node_tracker = node_config.tracker
72+
73+
model = None
74+
if node_config.model:
75+
lc_model = create_langchain_model(node_config)
76+
tool_defs = node_config.model.get_parameter('tools') or []
77+
tool_fns = [
78+
tools_ref[t.get('name', '')]
79+
for t in tool_defs
80+
if t.get('name', '') in tools_ref
81+
]
82+
model = lc_model.bind_tools(tool_fns) if tool_fns else lc_model
83+
84+
def invoke(state: WorkflowState) -> WorkflowState:
85+
exec_path.append(node_key)
86+
if not model:
87+
return {'messages': []}
88+
gk = tracker.graph_key if tracker is not None else None
89+
if node_tracker:
90+
response = node_tracker.track_metrics_of(
91+
lambda: model.invoke(state['messages']),
92+
get_ai_metrics_from_response,
93+
graph_key=gk,
94+
)
95+
node_tracker.track_tool_calls(
96+
get_tool_calls_from_response(response),
97+
graph_key=tracker.graph_key if tracker is not None else None,
98+
)
99+
else:
100+
response = model.invoke(state['messages'])
101+
102+
return {'messages': [response]}
103+
104+
invoke.__name__ = node_key
105+
106+
agent_builder.add_node(node_key, invoke)
107+
108+
if node_key == root_key:
109+
agent_builder.add_edge(START, node_key)
110+
111+
if node.is_terminal():
112+
agent_builder.add_edge(node_key, END)
113+
114+
for edge in node.get_edges():
115+
agent_builder.add_edge(node_key, edge.target_config)
116+
117+
return None
118+
119+
self._graph.traverse(fn=handle_traversal)
120+
compiled = agent_builder.compile()
121+
122+
result = await compiled.ainvoke( # type: ignore[call-overload]
123+
{'messages': [HumanMessage(content=str(input))]}
124+
)
125+
duration = (time.perf_counter_ns() - start_ns) // 1_000_000
126+
127+
output = ''
128+
messages = result.get('messages', [])
129+
if messages:
130+
last = messages[-1]
131+
if hasattr(last, 'content'):
132+
output = str(last.content)
133+
134+
if tracker:
135+
tracker.track_path(exec_path)
136+
tracker.track_latency(duration)
137+
tracker.track_invocation_success()
138+
tracker.track_total_tokens(
139+
sum_token_usage_from_messages(messages)
140+
)
141+
142+
return AgentGraphResult(
143+
output=output,
144+
raw=result,
145+
metrics=LDAIMetrics(success=True),
146+
)
147+
except Exception as exc:
148+
if isinstance(exc, ImportError):
149+
log.warning(
150+
"langgraph is required for LangGraphAgentGraphRunner. "
151+
"Install it with: pip install langgraph"
152+
)
153+
else:
154+
log.warning(f'LangGraphAgentGraphRunner run failed: {exc}')
155+
duration = (time.perf_counter_ns() - start_ns) // 1_000_000
156+
if tracker:
157+
tracker.track_latency(duration)
158+
tracker.track_invocation_failure()
159+
return AgentGraphResult(
160+
output='',
161+
raw=None,
162+
metrics=LDAIMetrics(success=False),
163+
)
Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
"""Tests for LangGraphAgentGraphRunner and LangChainRunnerFactory.create_agent_graph()."""
2+
3+
import pytest
4+
from unittest.mock import AsyncMock, MagicMock, patch
5+
6+
from ldai.agent_graph import AgentGraphDefinition
7+
from ldai.models import AIAgentGraphConfig, AIAgentConfig, ModelConfig, ProviderConfig
8+
from ldai.providers import AgentGraphResult, ToolRegistry
9+
from ldai_langchain.langgraph_agent_graph_runner import LangGraphAgentGraphRunner
10+
from ldai_langchain.langchain_runner_factory import LangChainRunnerFactory
11+
12+
13+
def _make_graph(enabled: bool = True) -> AgentGraphDefinition:
14+
root_config = AIAgentConfig(
15+
key='root-agent',
16+
enabled=enabled,
17+
model=ModelConfig(name='gpt-4'),
18+
provider=ProviderConfig(name='openai'),
19+
instructions='You are a helpful assistant.',
20+
tracker=MagicMock(),
21+
)
22+
graph_config = AIAgentGraphConfig(
23+
key='test-graph',
24+
root_config_key='root-agent',
25+
edges=[],
26+
enabled=enabled,
27+
)
28+
nodes = AgentGraphDefinition.build_nodes(graph_config, {'root-agent': root_config})
29+
return AgentGraphDefinition(
30+
agent_graph=graph_config,
31+
nodes=nodes,
32+
context=MagicMock(),
33+
enabled=enabled,
34+
tracker=MagicMock(),
35+
)
36+
37+
38+
# --- Factory ---
39+
40+
def test_langchain_runner_factory_create_agent_graph_returns_runner():
41+
graph = _make_graph()
42+
tools: ToolRegistry = {'fetch_weather': lambda loc: f'weather in {loc}'}
43+
factory = LangChainRunnerFactory()
44+
runner = factory.create_agent_graph(graph, tools)
45+
assert isinstance(runner, LangGraphAgentGraphRunner)
46+
47+
48+
def test_langchain_runner_factory_create_agent_graph_wires_graph_and_tools():
49+
graph = _make_graph()
50+
tools: ToolRegistry = {}
51+
factory = LangChainRunnerFactory()
52+
runner = factory.create_agent_graph(graph, tools)
53+
assert runner._graph is graph
54+
assert runner._tools is tools
55+
56+
57+
# --- LangGraphAgentGraphRunner ---
58+
59+
def test_langgraph_runner_stores_graph_and_tools():
60+
graph = _make_graph()
61+
tools: ToolRegistry = {}
62+
runner = LangGraphAgentGraphRunner(graph, tools)
63+
assert runner._graph is graph
64+
assert runner._tools is tools
65+
66+
67+
@pytest.mark.asyncio
68+
async def test_langgraph_runner_run_raises_when_langgraph_not_installed():
69+
graph = _make_graph()
70+
runner = LangGraphAgentGraphRunner(graph, {})
71+
72+
with patch.dict('sys.modules', {'langgraph': None, 'langgraph.graph': None}):
73+
result = await runner.run("test")
74+
assert isinstance(result, AgentGraphResult)
75+
assert result.metrics.success is False
76+
77+
78+
@pytest.mark.asyncio
79+
async def test_langgraph_runner_run_tracks_failure_on_exception():
80+
graph = _make_graph()
81+
tracker = graph.get_tracker()
82+
runner = LangGraphAgentGraphRunner(graph, {})
83+
84+
with patch.dict('sys.modules', {'langgraph': None, 'langgraph.graph': None}):
85+
result = await runner.run("fail")
86+
87+
assert result.metrics.success is False
88+
tracker.track_invocation_failure.assert_called_once()
89+
tracker.track_latency.assert_called_once()
90+
91+
92+
@pytest.mark.asyncio
93+
async def test_langgraph_runner_run_success():
94+
graph = _make_graph()
95+
tracker = graph.get_tracker()
96+
97+
mock_message = MagicMock()
98+
mock_message.content = "langgraph answer"
99+
mock_message.usage_metadata = None
100+
mock_message.response_metadata = None
101+
102+
mock_compiled = MagicMock()
103+
mock_compiled.ainvoke = AsyncMock(return_value={'messages': [mock_message]})
104+
105+
mock_state_graph_instance = MagicMock()
106+
mock_state_graph_instance.add_node = MagicMock()
107+
mock_state_graph_instance.add_edge = MagicMock()
108+
mock_state_graph_instance.compile = MagicMock(return_value=mock_compiled)
109+
110+
mock_langgraph_graph = MagicMock()
111+
mock_langgraph_graph.END = 'END'
112+
mock_langgraph_graph.START = 'START'
113+
mock_langgraph_graph.StateGraph = MagicMock(return_value=mock_state_graph_instance)
114+
115+
mock_human_message = MagicMock()
116+
mock_lc_core_messages = MagicMock()
117+
mock_lc_core_messages.HumanMessage = MagicMock(return_value=mock_human_message)
118+
mock_lc_core_messages.AnyMessage = MagicMock()
119+
120+
mock_model_response = MagicMock()
121+
mock_model_response.content = 'langgraph answer'
122+
mock_model_response.usage_metadata = None
123+
mock_model_response.response_metadata = None
124+
mock_model_response.tool_calls = None
125+
126+
mock_llm = MagicMock()
127+
mock_llm.invoke = MagicMock(return_value=mock_model_response)
128+
129+
mock_init_model = MagicMock()
130+
mock_init_model.return_value = mock_llm
131+
mock_langchain_chat = MagicMock()
132+
mock_langchain_chat.init_chat_model = mock_init_model
133+
134+
with patch.dict('sys.modules', {
135+
'langgraph': MagicMock(),
136+
'langgraph.graph': mock_langgraph_graph,
137+
'langchain_core': MagicMock(),
138+
'langchain_core.messages': mock_lc_core_messages,
139+
'langchain': MagicMock(),
140+
'langchain.chat_models': mock_langchain_chat,
141+
'typing_extensions': __import__('typing_extensions'),
142+
}):
143+
runner = LangGraphAgentGraphRunner(graph, {})
144+
result = await runner.run("find restaurants")
145+
146+
assert isinstance(result, AgentGraphResult)
147+
assert result.output == "langgraph answer"
148+
assert result.metrics.success is True
149+
tracker.track_path.assert_called_once_with([])
150+
tracker.track_invocation_success.assert_called_once()
151+
tracker.track_latency.assert_called_once()

packages/ai-providers/server-ai-openai/src/ldai_openai/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from ldai_openai.openai_agent_graph_runner import OpenAIAgentGraphRunner
12
from ldai_openai.openai_helper import (
23
convert_messages_to_openai,
34
get_ai_metrics_from_response,
@@ -8,6 +9,7 @@
89

910
__all__ = [
1011
'OpenAIRunnerFactory',
12+
'OpenAIAgentGraphRunner',
1113
'OpenAIModelRunner',
1214
'convert_messages_to_openai',
1315
'get_ai_metrics_from_response',

0 commit comments

Comments
 (0)