-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathtest_langgraph_agent_graph_runner.py
More file actions
155 lines (124 loc) · 5.43 KB
/
test_langgraph_agent_graph_runner.py
File metadata and controls
155 lines (124 loc) · 5.43 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
"""Tests for LangGraphAgentGraphRunner and LangChainRunnerFactory.create_agent_graph()."""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from ldai.agent_graph import AgentGraphDefinition
from ldai.evaluator import Evaluator
from ldai.models import AIAgentGraphConfig, AIAgentConfig, ModelConfig, ProviderConfig
from ldai.providers import AgentGraphResult, ToolRegistry
from ldai_langchain.langgraph_agent_graph_runner import LangGraphAgentGraphRunner
from ldai_langchain.langchain_runner_factory import LangChainRunnerFactory
def _make_graph(enabled: bool = True) -> AgentGraphDefinition:
    """Build a one-node agent graph definition for runner tests.

    The graph ``'test-graph'`` has a single root agent, ``'root-agent'``
    (model ``gpt-4``, provider ``openai``, no-op evaluator), and no edges.
    *enabled* is applied to the agent config, the graph config and the
    definition alike.

    Trackers are mocks. The definition's ``create_tracker`` returns the same
    ``MagicMock`` on every call, so a test can fetch it before ``run()`` and
    assert on it afterwards.
    """
    root_key = 'root-agent'
    shared_graph_tracker = MagicMock()
    agent_tracker = MagicMock()

    agent_cfg = AIAgentConfig(
        key=root_key,
        enabled=enabled,
        create_tracker=MagicMock(return_value=agent_tracker),
        model=ModelConfig(name='gpt-4'),
        provider=ProviderConfig(name='openai'),
        instructions='You are a helpful assistant.',
        evaluator=Evaluator.noop(),
    )
    graph_cfg = AIAgentGraphConfig(
        key='test-graph',
        root_config_key=root_key,
        edges=[],
        enabled=enabled,
    )

    def graph_tracker_factory():
        # Always hand back the same mock so every caller shares one call history.
        return shared_graph_tracker

    return AgentGraphDefinition(
        agent_graph=graph_cfg,
        nodes=AgentGraphDefinition.build_nodes(graph_cfg, {root_key: agent_cfg}),
        context=MagicMock(),
        enabled=enabled,
        create_tracker=graph_tracker_factory,
    )
# --- Factory ---
def test_langchain_runner_factory_create_agent_graph_returns_runner():
    """The LangChain factory's ``create_agent_graph`` builds a LangGraph runner."""
    graph = _make_graph()

    def fetch_weather(loc):
        return f'weather in {loc}'

    registry: ToolRegistry = {'fetch_weather': fetch_weather}
    built = LangChainRunnerFactory().create_agent_graph(graph, registry)

    assert isinstance(built, LangGraphAgentGraphRunner)
def test_langchain_runner_factory_create_agent_graph_wires_graph_and_tools():
    """The runner built by the factory keeps the very same graph and tool objects.

    Identity (``is``) is checked rather than equality, so a factory that
    copies or replaces the registry would fail.
    """
    definition = _make_graph()
    registry: ToolRegistry = {}

    built = LangChainRunnerFactory().create_agent_graph(definition, registry)

    assert built._graph is definition
    assert built._tools is registry
# --- LangGraphAgentGraphRunner ---
def test_langgraph_runner_stores_graph_and_tools():
    """Constructing the runner directly keeps references to its arguments."""
    definition = _make_graph()
    registry: ToolRegistry = {}

    subject = LangGraphAgentGraphRunner(definition, registry)

    assert subject._graph is definition
    assert subject._tools is registry
@pytest.mark.asyncio
async def test_langgraph_runner_run_raises_when_langgraph_not_installed():
    """Without an importable langgraph, ``run()`` reports failure in its result.

    Despite the test name, no exception is expected to escape ``run()``.
    The test asserts that an ``AgentGraphResult`` comes back with
    ``metrics.success`` set to ``False``.
    """
    subject = LangGraphAgentGraphRunner(_make_graph(), {})
    # A None entry in sys.modules makes importing that name raise ImportError.
    unavailable = dict.fromkeys(('langgraph', 'langgraph.graph'))

    with patch.dict('sys.modules', unavailable):
        outcome = await subject.run("test")

    assert isinstance(outcome, AgentGraphResult)
    assert outcome.metrics.success is False
@pytest.mark.asyncio
async def test_langgraph_runner_run_tracks_failure_on_exception():
    """A failed ``run()`` records one failure and one duration, and no success.

    The failure is forced by making langgraph unimportable. The graph
    tracker is fetched before the run; ``_make_graph`` returns the same mock
    from every ``create_tracker()`` call, so the runner's calls land on it.
    """
    graph = _make_graph()
    tracker = graph.create_tracker()
    runner = LangGraphAgentGraphRunner(graph, {})

    # None in sys.modules makes `import langgraph...` raise ImportError.
    with patch.dict('sys.modules', {'langgraph': None, 'langgraph.graph': None}):
        result = await runner.run("fail")

    assert isinstance(result, AgentGraphResult)
    assert result.metrics.success is False
    tracker.track_invocation_failure.assert_called_once()
    # Guard against a runner that tracks both outcomes on the error path.
    tracker.track_invocation_success.assert_not_called()
    tracker.track_duration.assert_called_once()
def _fake_langgraph_modules(answer: str) -> dict:
    """Return ``sys.modules`` overrides that stub out langgraph and langchain.

    Two mocked entry points yield a message whose ``content`` is *answer*:
    - the compiled ``StateGraph``'s ``ainvoke``;
    - the ``init_chat_model`` LLM's ``ainvoke``.

    Both messages carry ``usage_metadata=None`` and ``response_metadata=None``.
    The LLM response also has ``tool_calls=None``. Which of the two the runner
    actually awaits is not pinned here.

    ``typing_extensions`` is mapped to the real module; only langgraph and
    langchain are replaced.
    """
    # Message returned inside the compiled graph's final state.
    graph_message = MagicMock()
    graph_message.content = answer
    graph_message.usage_metadata = None
    graph_message.response_metadata = None

    compiled_graph = MagicMock()
    compiled_graph.ainvoke = AsyncMock(return_value={'messages': [graph_message]})

    # add_node / add_edge are auto-created by MagicMock; only compile() matters.
    state_graph = MagicMock()
    state_graph.compile = MagicMock(return_value=compiled_graph)

    langgraph_graph = MagicMock()
    langgraph_graph.END = 'END'
    langgraph_graph.START = 'START'
    langgraph_graph.StateGraph = MagicMock(return_value=state_graph)

    lc_messages = MagicMock()
    lc_messages.HumanMessage = MagicMock(return_value=MagicMock())
    lc_messages.AnyMessage = MagicMock()

    # Direct model reply: no metadata and no tool calls.
    model_response = MagicMock()
    model_response.content = answer
    model_response.usage_metadata = None
    model_response.response_metadata = None
    model_response.tool_calls = None

    llm = MagicMock()
    llm.ainvoke = AsyncMock(return_value=model_response)

    lc_chat_models = MagicMock()
    lc_chat_models.init_chat_model = MagicMock(return_value=llm)

    return {
        'langgraph': MagicMock(),
        'langgraph.graph': langgraph_graph,
        'langchain_core': MagicMock(),
        'langchain_core.messages': lc_messages,
        'langchain': MagicMock(),
        'langchain.chat_models': lc_chat_models,
        'typing_extensions': __import__('typing_extensions'),
    }


@pytest.mark.asyncio
async def test_langgraph_runner_run_success():
    """With langgraph/langchain stubbed, ``run()`` returns the answer and tracks success.

    The graph tracker is fetched before the run; ``_make_graph`` returns the
    same mock from every ``create_tracker()`` call, so the runner's calls
    land on it.
    """
    graph = _make_graph()
    tracker = graph.create_tracker()

    with patch.dict('sys.modules', _fake_langgraph_modules("langgraph answer")):
        runner = LangGraphAgentGraphRunner(graph, {})
        result = await runner.run("find restaurants")

    assert isinstance(result, AgentGraphResult)
    assert result.output == "langgraph answer"
    assert result.metrics.success is True
    tracker.track_path.assert_called_once_with([])
    tracker.track_invocation_success.assert_called_once()
    # Guard against a runner that also tracks a failure on the happy path.
    tracker.track_invocation_failure.assert_not_called()
    tracker.track_duration.assert_called_once()