# test_langgraph_agent_graph_runner.py
"""Tests for LangGraphAgentGraphRunner and LangChainRunnerFactory.create_agent_graph()."""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from ldai.agent_graph import AgentGraphDefinition
from ldai.evaluator import Evaluator
from ldai.models import AIAgentConfig, AIAgentGraphConfig, ModelConfig, ProviderConfig
from ldai.providers import ToolRegistry
from ldai.providers.types import AgentGraphRunnerResult
from ldai_langchain.langchain_runner_factory import LangChainRunnerFactory
from ldai_langchain.langgraph_agent_graph_runner import LangGraphAgentGraphRunner
def _make_graph(enabled: bool = True) -> AgentGraphDefinition:
    """Build a one-node ``AgentGraphDefinition`` fixture backed by mocks.

    The graph consists of a single root agent keyed ``'root-agent'`` and has
    no edges. The node-level and graph-level tracker factories both hand back
    ``MagicMock`` objects, so tests can inspect tracker calls.

    Args:
        enabled: Flag applied to the agent config, the graph config and the
            resulting graph definition alike.

    Returns:
        The assembled graph definition.
    """
    root_key = 'root-agent'
    per_node_tracker = MagicMock()
    whole_graph_tracker = MagicMock()

    def _graph_tracker_factory():
        # Always returns the same mock so callers can assert against it.
        return whole_graph_tracker

    agent_cfg = AIAgentConfig(
        key=root_key,
        enabled=enabled,
        create_tracker=MagicMock(return_value=per_node_tracker),
        model=ModelConfig(name='gpt-4'),
        provider=ProviderConfig(name='openai'),
        instructions='You are a helpful assistant.',
        evaluator=Evaluator.noop(),
    )
    graph_cfg = AIAgentGraphConfig(
        key='test-graph',
        root_config_key=root_key,
        edges=[],
        enabled=enabled,
    )
    built_nodes = AgentGraphDefinition.build_nodes(graph_cfg, {root_key: agent_cfg})

    return AgentGraphDefinition(
        agent_graph=graph_cfg,
        nodes=built_nodes,
        context=MagicMock(),
        enabled=enabled,
        create_tracker=_graph_tracker_factory,
    )
# --- Factory ---
def test_langchain_runner_factory_create_agent_graph_returns_runner():
    """create_agent_graph() hands back a LangGraphAgentGraphRunner instance."""
    registry: ToolRegistry = {'fetch_weather': lambda loc: f'weather in {loc}'}
    definition = _make_graph()

    built = LangChainRunnerFactory().create_agent_graph(definition, registry)

    assert isinstance(built, LangGraphAgentGraphRunner)
def test_langchain_runner_factory_create_agent_graph_wires_graph_and_tools():
    """The factory-built runner holds the very graph and tools objects given."""
    definition = _make_graph()
    registry: ToolRegistry = {}

    built = LangChainRunnerFactory().create_agent_graph(definition, registry)

    # Identity, not equality: the runner must keep the same objects.
    assert built._graph is definition
    assert built._tools is registry
# --- LangGraphAgentGraphRunner ---
def test_langgraph_runner_stores_graph_and_tools():
    """Constructing the runner directly keeps references to its arguments."""
    definition = _make_graph()
    registry: ToolRegistry = {}

    direct_runner = LangGraphAgentGraphRunner(definition, registry)

    # Identity, not equality: the runner must keep the same objects.
    assert direct_runner._graph is definition
    assert direct_runner._tools is registry
@pytest.mark.asyncio
async def test_langgraph_runner_run_raises_when_langgraph_not_installed():
    """run() reports failure when the langgraph modules cannot be imported.

    Despite the test name, the expectation checked here is a returned
    ``AgentGraphRunnerResult`` whose metrics mark failure, not a raised
    exception.
    """
    # A None entry in sys.modules makes importing that module fail.
    blocked_modules = {'langgraph': None, 'langgraph.graph': None}
    subject = LangGraphAgentGraphRunner(_make_graph(), {})

    with patch.dict('sys.modules', blocked_modules):
        outcome = await subject.run("test")

    assert isinstance(outcome, AgentGraphRunnerResult)
    assert outcome.metrics.success is False
@pytest.mark.asyncio
async def test_langgraph_runner_run_returns_failure_on_exception():
    """A failing run yields an AgentGraphRunnerResult with failure metrics.

    The failure is provoked by blocking the langgraph imports; the result
    must flag ``success`` as False and still carry a duration.
    """
    # A None entry in sys.modules makes importing that module fail.
    blocked_modules = {'langgraph': None, 'langgraph.graph': None}
    subject = LangGraphAgentGraphRunner(_make_graph(), {})

    with patch.dict('sys.modules', blocked_modules):
        outcome = await subject.run("fail")

    assert isinstance(outcome, AgentGraphRunnerResult)
    assert outcome.metrics.success is False
    assert outcome.metrics.duration_ms is not None
@pytest.mark.asyncio
async def test_langgraph_runner_run_success():
    """Run the runner against fully mocked langgraph/langchain modules.

    Checks that run() returns an ``AgentGraphRunnerResult`` carrying a
    duration, and that the graph tracker is left untouched by the runner.
    """
    graph = _make_graph()
    tracker = graph.create_tracker()

    # Message held in the state that the compiled graph's ainvoke() resolves to.
    final_message = MagicMock(
        content="langgraph answer",
        usage_metadata=None,
        response_metadata=None,
    )
    compiled_graph = MagicMock(
        ainvoke=AsyncMock(return_value={'messages': [final_message]}),
    )

    # Stand-in for the object returned by langgraph.graph.StateGraph(...).
    state_graph = MagicMock(
        add_node=MagicMock(),
        add_edge=MagicMock(),
        compile=MagicMock(return_value=compiled_graph),
    )
    fake_langgraph_graph = MagicMock(
        END='END',
        START='START',
        StateGraph=MagicMock(return_value=state_graph),
    )

    # Stand-in for langchain_core.messages.
    human_message = MagicMock()
    fake_core_messages = MagicMock(
        HumanMessage=MagicMock(return_value=human_message),
        AnyMessage=MagicMock(),
    )

    # Stand-in chat model whose ainvoke() resolves to a reply without tool calls.
    llm_reply = MagicMock(
        content='langgraph answer',
        usage_metadata=None,
        response_metadata=None,
        tool_calls=None,
    )
    fake_llm = MagicMock(ainvoke=AsyncMock(return_value=llm_reply))
    fake_chat_models = MagicMock(init_chat_model=MagicMock(return_value=fake_llm))

    stubbed_modules = {
        'langgraph': MagicMock(),
        'langgraph.graph': fake_langgraph_graph,
        'langchain_core': MagicMock(),
        'langchain_core.messages': fake_core_messages,
        'langchain': MagicMock(),
        'langchain.chat_models': fake_chat_models,
        'typing_extensions': __import__('typing_extensions'),
    }

    with patch.dict('sys.modules', stubbed_modules):
        subject = LangGraphAgentGraphRunner(graph, {})
        outcome = await subject.run("find restaurants")

    assert isinstance(outcome, AgentGraphRunnerResult)
    assert outcome.metrics.duration_ms is not None

    # Tracker events now fire from the managed layer (ManagedAgentGraph) using
    # result.metrics; the runner no longer touches the graph tracker directly.
    tracker.track_path.assert_not_called()
    tracker.track_invocation_success.assert_not_called()
    tracker.track_duration.assert_not_called()