Skip to content

Commit 226dfdf

Browse files
jsonbaileyclaude
andcommitted
feat: add ManagedAgentGraph, OpenAIAgentGraphRunner, LangGraphAgentGraphRunner
Implements PR 5 — ManagedAgentGraph + create_agent_graph(): ldai: - managed_agent_graph.py: ManagedAgentGraph wrapper holding AgentGraphRunner + AIGraphTracker; exposes run(), get_agent_graph_runner(), get_tracker() - LDAIClient.create_agent_graph(key, context, tools): resolves graph via agent_graph(), delegates to RunnerFactory, returns ManagedAgentGraph - Exports ManagedAgentGraph from top-level ldai package ldai_openai: - OpenAIAgentGraphRunner(AgentGraphRunner): builds agents via reverse_traverse using the openai-agents SDK; auto-tracks path, tool calls, handoffs, latency, invocation success/failure - OpenAIRunnerFactory.create_agent_graph(graph_def, tools) -> OpenAIAgentGraphRunner ldai_langchain: - LangGraphAgentGraphRunner(AgentGraphRunner): builds a LangGraph StateGraph via traverse(); auto-tracks latency and invocation success/failure - LangChainRunnerFactory.create_agent_graph(graph_def, tools) -> LangGraphAgentGraphRunner Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent e2df180 commit 226dfdf

12 files changed

Lines changed: 963 additions & 2 deletions

File tree

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
from ldai_langchain.langchain_helper import LangChainHelper
22
from ldai_langchain.langchain_model_runner import LangChainModelRunner
33
from ldai_langchain.langchain_runner_factory import LangChainRunnerFactory
4+
from ldai_langchain.langgraph_agent_graph_runner import LangGraphAgentGraphRunner
45

56
__version__ = "0.1.0"
67

78
__all__ = [
89
'__version__',
910
'LangChainRunnerFactory',
11+
'LangGraphAgentGraphRunner',
1012
'LangChainHelper',
1113
'LangChainModelRunner',
1214
]

packages/ai-providers/server-ai-langchain/src/ldai_langchain/langchain_runner_factory.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,26 @@
1+
from typing import Any
2+
13
from ldai.models import AIConfigKind
24
from ldai.providers import AIProvider
3-
5+
from ldai.runners.types import ToolRegistry
46
from ldai_langchain.langchain_helper import LangChainHelper
57
from ldai_langchain.langchain_model_runner import LangChainModelRunner
68

79

810
class LangChainRunnerFactory(AIProvider):
911
"""LangChain ``AIProvider`` implementation for the LaunchDarkly AI SDK."""
1012

13+
def create_agent_graph(self, graph_def: Any, tools: ToolRegistry) -> Any:
14+
"""
15+
Create a configured LangGraphAgentGraphRunner for the given graph definition.
16+
17+
:param graph_def: The AgentGraphDefinition to execute
18+
:param tools: Registry mapping tool names to callables (langchain-compatible)
19+
:return: LangGraphAgentGraphRunner ready to execute the graph
20+
"""
21+
from ldai_langchain.langgraph_agent_graph_runner import LangGraphAgentGraphRunner
22+
return LangGraphAgentGraphRunner(graph_def, tools)
23+
1124
def create_model(self, config: AIConfigKind) -> LangChainModelRunner:
1225
"""
1326
Create a configured LangChainModelRunner for the given AI config.
Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
"""LangGraph agent graph runner for LaunchDarkly AI SDK."""
2+
3+
import operator
4+
import time
5+
from typing import Annotated, Any, List
6+
7+
from ldai.agent_graph import AgentGraphDefinition, AgentGraphNode
8+
from ldai.providers.types import LDAIMetrics
9+
from ldai.runners.agent_graph_runner import AgentGraphRunner
10+
from ldai.runners.types import AgentGraphResult, ToolRegistry
11+
12+
13+
class LangGraphAgentGraphRunner(AgentGraphRunner):
14+
"""
15+
AgentGraphRunner implementation for LangGraph.
16+
17+
Builds a LangGraph StateGraph from an AgentGraphDefinition and
18+
ToolRegistry via traverse(), compiles it, and executes it with
19+
ainvoke(). Auto-tracks latency and invocation success/failure via
20+
the graph's AIGraphTracker.
21+
22+
Requires ``langgraph`` to be installed.
23+
"""
24+
25+
def __init__(self, graph: AgentGraphDefinition, tools: ToolRegistry):
26+
"""
27+
Initialize the runner.
28+
29+
:param graph: The AgentGraphDefinition to execute
30+
:param tools: Registry mapping tool names to callables (langchain-compatible)
31+
"""
32+
self._graph = graph
33+
self._tools = tools
34+
35+
async def run(self, input: Any) -> AgentGraphResult:
36+
"""
37+
Run the agent graph with the given input.
38+
39+
Builds a LangGraph StateGraph from the AgentGraphDefinition, compiles
40+
it, and invokes it. Tracks latency and invocation success/failure.
41+
42+
:param input: The string prompt to send to the agent graph
43+
:return: AgentGraphResult with the final output and metrics
44+
"""
45+
tracker = self._graph.get_tracker()
46+
start_time = time.time()
47+
try:
48+
try:
49+
from langchain.chat_models import init_chat_model
50+
from langchain_core.messages import AnyMessage, HumanMessage
51+
from langgraph.graph import END, START, StateGraph
52+
from typing_extensions import TypedDict
53+
except ImportError as exc:
54+
raise ImportError(
55+
"langgraph is required for LangGraphAgentGraphRunner. "
56+
"Install it with: pip install langgraph"
57+
) from exc
58+
59+
class WorkflowState(TypedDict):
60+
messages: Annotated[List[AnyMessage], operator.add]
61+
62+
agent_builder: StateGraph = StateGraph(WorkflowState)
63+
root_node = self._graph.root()
64+
root_key = root_node.get_key() if root_node else None
65+
tools_ref = self._tools
66+
67+
def handle_traversal(node: AgentGraphNode, ctx: dict) -> None:
68+
node_config = node.get_config()
69+
node_key = node.get_key()
70+
71+
model = None
72+
if node_config.model:
73+
lc_model = init_chat_model(model=node_config.model.name)
74+
tool_defs = node_config.model.get_parameter('tools') or []
75+
tool_fns = [
76+
tools_ref[t.get('name', '')]
77+
for t in tool_defs
78+
if t.get('name', '') in tools_ref
79+
]
80+
if tool_fns:
81+
lc_model = lc_model.bind_tools(tool_fns)
82+
model = lc_model
83+
84+
def invoke(state: WorkflowState) -> WorkflowState:
85+
if model:
86+
response = model.invoke(state['messages'])
87+
return {'messages': [response]}
88+
return state
89+
90+
invoke.__name__ = node_key
91+
92+
agent_builder.add_node(name=node_key, node=invoke)
93+
94+
if node_key == root_key:
95+
agent_builder.add_edge(START, node_key)
96+
97+
if node.is_terminal():
98+
agent_builder.add_edge(node_key, END)
99+
100+
for edge in node.get_edges():
101+
agent_builder.add_edge(node_key, edge.target_config)
102+
103+
return None
104+
105+
self._graph.traverse(fn=handle_traversal)
106+
compiled = agent_builder.compile()
107+
108+
result = await compiled.ainvoke(
109+
{'messages': [HumanMessage(content=str(input))]}
110+
)
111+
duration = int((time.time() - start_time) * 1000)
112+
113+
output = ''
114+
messages = result.get('messages', [])
115+
if messages:
116+
last = messages[-1]
117+
if hasattr(last, 'content'):
118+
output = str(last.content)
119+
120+
if tracker:
121+
tracker.track_latency(duration)
122+
tracker.track_invocation_success()
123+
124+
return AgentGraphResult(
125+
output=output,
126+
raw=result,
127+
metrics=LDAIMetrics(success=True),
128+
)
129+
except Exception:
130+
duration = int((time.time() - start_time) * 1000)
131+
if tracker:
132+
tracker.track_latency(duration)
133+
tracker.track_invocation_failure()
134+
return AgentGraphResult(
135+
output='',
136+
raw=None,
137+
metrics=LDAIMetrics(success=False),
138+
)
Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
"""Tests for LangGraphAgentGraphRunner and LangChainRunnerFactory.create_agent_graph()."""
2+
3+
import pytest
4+
from unittest.mock import AsyncMock, MagicMock, patch
5+
6+
from ldai.agent_graph import AgentGraphDefinition
7+
from ldai.models import AIAgentGraphConfig, AIAgentConfig, ModelConfig, ProviderConfig
8+
from ldai.runners.types import AgentGraphResult, ToolRegistry
9+
from ldai_langchain.langgraph_agent_graph_runner import LangGraphAgentGraphRunner
10+
from ldai_langchain.langchain_runner_factory import LangChainRunnerFactory
11+
12+
13+
def _make_graph(enabled: bool = True) -> AgentGraphDefinition:
14+
root_config = AIAgentConfig(
15+
key='root-agent',
16+
enabled=enabled,
17+
model=ModelConfig(name='gpt-4'),
18+
provider=ProviderConfig(name='openai'),
19+
instructions='You are a helpful assistant.',
20+
tracker=MagicMock(),
21+
)
22+
graph_config = AIAgentGraphConfig(
23+
key='test-graph',
24+
root_config_key='root-agent',
25+
edges=[],
26+
enabled=enabled,
27+
)
28+
nodes = AgentGraphDefinition.build_nodes(graph_config, {'root-agent': root_config})
29+
return AgentGraphDefinition(
30+
agent_graph=graph_config,
31+
nodes=nodes,
32+
context=MagicMock(),
33+
enabled=enabled,
34+
tracker=MagicMock(),
35+
)
36+
37+
38+
# --- Factory ---
39+
40+
def test_langchain_runner_factory_create_agent_graph_returns_runner():
41+
graph = _make_graph()
42+
tools: ToolRegistry = {'fetch_weather': lambda loc: f'weather in {loc}'}
43+
factory = LangChainRunnerFactory()
44+
runner = factory.create_agent_graph(graph, tools)
45+
assert isinstance(runner, LangGraphAgentGraphRunner)
46+
47+
48+
def test_langchain_runner_factory_create_agent_graph_wires_graph_and_tools():
49+
graph = _make_graph()
50+
tools: ToolRegistry = {}
51+
factory = LangChainRunnerFactory()
52+
runner = factory.create_agent_graph(graph, tools)
53+
assert runner._graph is graph
54+
assert runner._tools is tools
55+
56+
57+
# --- LangGraphAgentGraphRunner ---
58+
59+
def test_langgraph_runner_stores_graph_and_tools():
60+
graph = _make_graph()
61+
tools: ToolRegistry = {}
62+
runner = LangGraphAgentGraphRunner(graph, tools)
63+
assert runner._graph is graph
64+
assert runner._tools is tools
65+
66+
67+
@pytest.mark.asyncio
68+
async def test_langgraph_runner_run_raises_when_langgraph_not_installed():
69+
graph = _make_graph()
70+
runner = LangGraphAgentGraphRunner(graph, {})
71+
72+
with patch.dict('sys.modules', {'langgraph': None, 'langgraph.graph': None}):
73+
result = await runner.run("test")
74+
assert isinstance(result, AgentGraphResult)
75+
assert result.metrics.success is False
76+
77+
78+
@pytest.mark.asyncio
79+
async def test_langgraph_runner_run_tracks_failure_on_exception():
80+
graph = _make_graph()
81+
tracker = graph.get_tracker()
82+
runner = LangGraphAgentGraphRunner(graph, {})
83+
84+
with patch.dict('sys.modules', {'langgraph': None, 'langgraph.graph': None}):
85+
result = await runner.run("fail")
86+
87+
assert result.metrics.success is False
88+
tracker.track_invocation_failure.assert_called_once()
89+
tracker.track_latency.assert_called_once()
90+
91+
92+
@pytest.mark.asyncio
93+
async def test_langgraph_runner_run_success():
94+
graph = _make_graph()
95+
tracker = graph.get_tracker()
96+
97+
mock_message = MagicMock()
98+
mock_message.content = "langgraph answer"
99+
100+
mock_compiled = MagicMock()
101+
mock_compiled.ainvoke = AsyncMock(return_value={'messages': [mock_message]})
102+
103+
mock_state_graph_instance = MagicMock()
104+
mock_state_graph_instance.add_node = MagicMock()
105+
mock_state_graph_instance.add_edge = MagicMock()
106+
mock_state_graph_instance.compile = MagicMock(return_value=mock_compiled)
107+
108+
mock_langgraph_graph = MagicMock()
109+
mock_langgraph_graph.END = 'END'
110+
mock_langgraph_graph.START = 'START'
111+
mock_langgraph_graph.StateGraph = MagicMock(return_value=mock_state_graph_instance)
112+
113+
mock_human_message = MagicMock()
114+
mock_lc_core_messages = MagicMock()
115+
mock_lc_core_messages.HumanMessage = MagicMock(return_value=mock_human_message)
116+
mock_lc_core_messages.AnyMessage = MagicMock()
117+
118+
mock_init_model = MagicMock()
119+
mock_init_model.return_value = MagicMock()
120+
mock_langchain_chat = MagicMock()
121+
mock_langchain_chat.init_chat_model = mock_init_model
122+
123+
with patch.dict('sys.modules', {
124+
'langgraph': MagicMock(),
125+
'langgraph.graph': mock_langgraph_graph,
126+
'langchain_core': MagicMock(),
127+
'langchain_core.messages': mock_lc_core_messages,
128+
'langchain': MagicMock(),
129+
'langchain.chat_models': mock_langchain_chat,
130+
'typing_extensions': __import__('typing_extensions'),
131+
}):
132+
runner = LangGraphAgentGraphRunner(graph, {})
133+
result = await runner.run("find restaurants")
134+
135+
assert isinstance(result, AgentGraphResult)
136+
assert result.output == "langgraph answer"
137+
assert result.metrics.success is True
138+
tracker.track_invocation_success.assert_called_once()
139+
tracker.track_latency.assert_called_once()
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1+
from ldai_openai.openai_agent_graph_runner import OpenAIAgentGraphRunner
12
from ldai_openai.openai_helper import OpenAIHelper
23
from ldai_openai.openai_model_runner import OpenAIModelRunner
34
from ldai_openai.openai_runner_factory import OpenAIRunnerFactory
45

56
__all__ = [
67
'OpenAIRunnerFactory',
8+
'OpenAIAgentGraphRunner',
79
'OpenAIHelper',
810
'OpenAIModelRunner',
911
]

0 commit comments

Comments
 (0)