Skip to content

Commit c8318a2

Browse files
committed
adjusting native tool use and adding graph tests
1 parent 58d4414 commit c8318a2

6 files changed

Lines changed: 603 additions & 29 deletions

File tree

Lines changed: 260 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,260 @@
1+
"""
2+
Integration tests for LangGraphAgentGraphRunner tracking pipeline.
3+
4+
Uses real AIGraphTracker and LDAIConfigTracker backed by a mock LD client,
5+
and a fake LangChain model to verify that the correct LD events are emitted
6+
with the correct payloads — without making real API calls.
7+
"""
8+
9+
import pytest
10+
from collections import defaultdict
11+
from unittest.mock import MagicMock, patch
12+
13+
from ldai.agent_graph import AgentGraphDefinition
14+
from ldai.models import AIAgentGraphConfig, AIAgentConfig, ModelConfig, ProviderConfig
15+
from ldai.tracker import AIGraphTracker, LDAIConfigTracker
16+
from ldai_langchain.langgraph_agent_graph_runner import LangGraphAgentGraphRunner
17+
18+
pytestmark = pytest.mark.skipif(
19+
pytest.importorskip('langgraph', reason='langgraph not installed') is None,
20+
reason='langgraph not installed',
21+
)
22+
23+
24+
# ---------------------------------------------------------------------------
25+
# Helpers
26+
# ---------------------------------------------------------------------------
27+
28+
def _make_graph(
29+
mock_ld_client: MagicMock,
30+
node_key: str = 'root-agent',
31+
graph_key: str = 'test-graph',
32+
tool_names: list = None,
33+
) -> AgentGraphDefinition:
34+
"""
35+
Build an AgentGraphDefinition backed by real tracker objects that record
36+
events to a mock LD client.
37+
"""
38+
context = MagicMock()
39+
40+
node_tracker = LDAIConfigTracker(
41+
ld_client=mock_ld_client,
42+
variation_key='test-variation',
43+
config_key=node_key,
44+
version=1,
45+
model_name='gpt-4',
46+
provider_name='openai',
47+
context=context,
48+
)
49+
50+
graph_tracker = AIGraphTracker(
51+
ld_client=mock_ld_client,
52+
variation_key='test-variation',
53+
graph_key=graph_key,
54+
version=1,
55+
context=context,
56+
)
57+
58+
tool_defs = (
59+
[{'name': name, 'type': 'function', 'description': '', 'parameters': {}}
60+
for name in tool_names]
61+
if tool_names else None
62+
)
63+
64+
root_config = AIAgentConfig(
65+
key=node_key,
66+
enabled=True,
67+
model=ModelConfig(name='gpt-4', parameters={'tools': tool_defs} if tool_defs else {}),
68+
provider=ProviderConfig(name='openai'),
69+
instructions='You are a helpful assistant.',
70+
tracker=node_tracker,
71+
)
72+
73+
graph_config = AIAgentGraphConfig(
74+
key=graph_key,
75+
root_config_key=node_key,
76+
edges=[],
77+
enabled=True,
78+
)
79+
80+
nodes = AgentGraphDefinition.build_nodes(graph_config, {node_key: root_config})
81+
return AgentGraphDefinition(
82+
agent_graph=graph_config,
83+
nodes=nodes,
84+
context=context,
85+
enabled=True,
86+
tracker=graph_tracker,
87+
)
88+
89+
90+
def _make_fake_response(
91+
content: str,
92+
input_tokens: int = 10,
93+
output_tokens: int = 5,
94+
tool_call_names: list = None,
95+
):
96+
"""Create a real AIMessage with usage metadata and optional tool calls."""
97+
from langchain_core.messages import AIMessage
98+
99+
tool_calls = [
100+
{'name': name, 'args': {}, 'id': f'call_{i}', 'type': 'tool_call'}
101+
for i, name in enumerate(tool_call_names or [])
102+
]
103+
104+
return AIMessage(
105+
content=content,
106+
tool_calls=tool_calls,
107+
usage_metadata={
108+
'input_tokens': input_tokens,
109+
'output_tokens': output_tokens,
110+
'total_tokens': input_tokens + output_tokens,
111+
},
112+
)
113+
114+
115+
def _events(mock_ld_client: MagicMock) -> dict:
116+
"""Return dict of event_name -> list of (data, value) from all track() calls."""
117+
result = defaultdict(list)
118+
for call in mock_ld_client.track.call_args_list:
119+
name, _ctx, data, value = call.args
120+
result[name].append((data, value))
121+
return dict(result)
122+
123+
124+
def _mock_model(response):
125+
"""Return a mock LangChain model that always returns response on invoke()."""
126+
model = MagicMock()
127+
model.invoke.return_value = response
128+
model.bind_tools.return_value = model
129+
return model
130+
131+
132+
# ---------------------------------------------------------------------------
133+
# Tests
134+
# ---------------------------------------------------------------------------
135+
136+
@pytest.mark.asyncio
137+
async def test_tracks_node_and_graph_tokens_on_success():
138+
"""Node-level and graph-level token events fire with the correct counts."""
139+
mock_ld_client = MagicMock()
140+
graph = _make_graph(mock_ld_client)
141+
fake_response = _make_fake_response('Sunny.', input_tokens=10, output_tokens=5)
142+
143+
with patch('ldai_langchain.langgraph_agent_graph_runner.create_langchain_model',
144+
return_value=_mock_model(fake_response)):
145+
runner = LangGraphAgentGraphRunner(graph, {})
146+
result = await runner.run("What's the weather?")
147+
148+
assert result.metrics.success is True
149+
assert result.output == 'Sunny.'
150+
151+
ev = _events(mock_ld_client)
152+
153+
# Node-level token events
154+
assert ev['$ld:ai:tokens:total'][0][1] == 15
155+
assert ev['$ld:ai:tokens:input'][0][1] == 10
156+
assert ev['$ld:ai:tokens:output'][0][1] == 5
157+
assert ev['$ld:ai:generation:success'][0][1] == 1
158+
assert '$ld:ai:duration:total' in ev
159+
160+
# Graph-level events
161+
assert ev['$ld:ai:graph:total_tokens'][0][1] == 15
162+
assert ev['$ld:ai:graph:invocation_success'][0][1] == 1
163+
assert '$ld:ai:graph:latency' in ev
164+
assert '$ld:ai:graph:path' in ev
165+
166+
167+
@pytest.mark.asyncio
168+
async def test_tracks_execution_path():
169+
"""The path event contains the executed node key."""
170+
mock_ld_client = MagicMock()
171+
graph = _make_graph(mock_ld_client, node_key='my-agent')
172+
fake_response = _make_fake_response('Done.')
173+
174+
with patch('ldai_langchain.langgraph_agent_graph_runner.create_langchain_model',
175+
return_value=_mock_model(fake_response)):
176+
runner = LangGraphAgentGraphRunner(graph, {})
177+
await runner.run('hello')
178+
179+
ev = _events(mock_ld_client)
180+
path_data = ev['$ld:ai:graph:path'][0][0]
181+
assert 'my-agent' in path_data['path']
182+
183+
184+
@pytest.mark.asyncio
185+
async def test_tracks_tool_calls():
186+
"""A tool_call event fires for each tool name found in the model response."""
187+
mock_ld_client = MagicMock()
188+
graph = _make_graph(mock_ld_client, tool_names=['get_weather'])
189+
fake_response = _make_fake_response('Calling tool.', tool_call_names=['get_weather'])
190+
191+
tool_registry = {'get_weather': lambda location='NYC': 'sunny'}
192+
193+
with patch('ldai_langchain.langgraph_agent_graph_runner.create_langchain_model',
194+
return_value=_mock_model(fake_response)):
195+
runner = LangGraphAgentGraphRunner(graph, tool_registry)
196+
await runner.run('What is the weather?')
197+
198+
ev = _events(mock_ld_client)
199+
tool_events = ev.get('$ld:ai:tool_call', [])
200+
assert len(tool_events) == 1
201+
assert tool_events[0][0]['toolKey'] == 'get_weather'
202+
203+
204+
@pytest.mark.asyncio
205+
async def test_tracks_multiple_tool_calls():
206+
"""One tool_call event fires per tool name in the response."""
207+
mock_ld_client = MagicMock()
208+
graph = _make_graph(mock_ld_client, tool_names=['search', 'summarize'])
209+
fake_response = _make_fake_response('Done.', tool_call_names=['search', 'summarize'])
210+
211+
tool_registry = {'search': lambda q='': q, 'summarize': lambda text='': text}
212+
213+
with patch('ldai_langchain.langgraph_agent_graph_runner.create_langchain_model',
214+
return_value=_mock_model(fake_response)):
215+
runner = LangGraphAgentGraphRunner(graph, tool_registry)
216+
await runner.run('Search and summarize.')
217+
218+
ev = _events(mock_ld_client)
219+
tool_keys = [data['toolKey'] for data, _ in ev.get('$ld:ai:tool_call', [])]
220+
assert sorted(tool_keys) == ['search', 'summarize']
221+
222+
223+
@pytest.mark.asyncio
224+
async def test_tracks_graph_key_on_node_events():
225+
"""Node-level events include the graphKey so they can be correlated to the graph."""
226+
mock_ld_client = MagicMock()
227+
graph = _make_graph(mock_ld_client, graph_key='my-graph')
228+
fake_response = _make_fake_response('OK.', input_tokens=5, output_tokens=3)
229+
230+
with patch('ldai_langchain.langgraph_agent_graph_runner.create_langchain_model',
231+
return_value=_mock_model(fake_response)):
232+
runner = LangGraphAgentGraphRunner(graph, {})
233+
await runner.run('hello')
234+
235+
ev = _events(mock_ld_client)
236+
token_data = ev['$ld:ai:tokens:total'][0][0]
237+
assert token_data.get('graphKey') == 'my-graph'
238+
239+
240+
@pytest.mark.asyncio
241+
async def test_tracks_failure_and_latency_on_model_error():
242+
"""When the model raises, failure and latency events fire; success does not."""
243+
mock_ld_client = MagicMock()
244+
graph = _make_graph(mock_ld_client)
245+
246+
error_model = MagicMock()
247+
error_model.invoke.side_effect = RuntimeError('model error')
248+
error_model.bind_tools.return_value = error_model
249+
250+
with patch('ldai_langchain.langgraph_agent_graph_runner.create_langchain_model',
251+
return_value=error_model):
252+
runner = LangGraphAgentGraphRunner(graph, {})
253+
result = await runner.run('fail')
254+
255+
assert result.metrics.success is False
256+
257+
ev = _events(mock_ld_client)
258+
assert '$ld:ai:graph:invocation_failure' in ev
259+
assert '$ld:ai:graph:latency' in ev
260+
assert '$ld:ai:graph:invocation_success' not in ev

packages/ai-providers/server-ai-openai/src/ldai_openai/openai_agent_graph_runner.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
from ldai.tracker import TokenUsage
1111

1212
from ldai_openai.openai_helper import (
13-
NATIVE_OPENAI_TOOLS,
1413
extract_usage_from_request_entry,
1514
get_ai_usage_from_response,
1615
get_tool_calls_from_run_items,
@@ -167,11 +166,6 @@ def build_node(node: AgentGraphNode, ctx: dict) -> Any:
167166
for tool_def in tool_defs:
168167
tool_name = tool_def.get('name', '')
169168

170-
# Check native OpenAI tools first, then fall back to ToolRegistry
171-
if tool_name in NATIVE_OPENAI_TOOLS:
172-
agent_tools.append(NATIVE_OPENAI_TOOLS[tool_name](tool_def))
173-
continue
174-
175169
tool_fn = self._tools.get(tool_name)
176170
if not tool_fn:
177171
continue

packages/ai-providers/server-ai-openai/src/ldai_openai/openai_agent_runner.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,7 @@
66
from ldai.providers import AgentResult, AgentRunner, ToolRegistry
77
from ldai.providers.types import LDAIMetrics
88

9-
from ldai_openai.openai_helper import (
10-
NATIVE_OPENAI_TOOLS,
11-
get_ai_usage_from_response,
12-
)
9+
from ldai_openai.openai_helper import get_ai_usage_from_response
1310

1411

1512
class OpenAIAgentRunner(AgentRunner):
@@ -99,15 +96,9 @@ def _build_agent_tools(self) -> List[Any]:
9996
tools.append(function_tool(tool_fn))
10097
continue
10198

102-
# No callable in registry — try native OpenAI tool (exact name match required).
103-
native = NATIVE_OPENAI_TOOLS.get(name)
104-
if native:
105-
tools.append(native(td))
106-
continue
107-
10899
log.warning(
109100
f"Tool '{name}' is defined in the AI config but was not found in "
110-
"the tool registry and is not a known native tool; skipping."
101+
"the tool registry; skipping."
111102
)
112103
return tools
113104

packages/ai-providers/server-ai-openai/src/ldai_openai/openai_helper.py

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,6 @@
66
from openai.types.chat import ChatCompletionMessageParam
77

88

9-
def _build_native_tool_map() -> Dict[str, Any]:
10-
try:
11-
from agents import WebSearchTool
12-
return {
13-
'web_search_tool': lambda _: WebSearchTool(),
14-
}
15-
except ImportError:
16-
return {}
17-
18-
19-
NATIVE_OPENAI_TOOLS: Dict[str, Any] = _build_native_tool_map()
20-
219

2210
def convert_messages_to_openai(messages: List[LDMessage]) -> Iterable[ChatCompletionMessageParam]:
2311
"""
@@ -91,6 +79,35 @@ def get_ai_metrics_from_response(response: Any) -> LDAIMetrics:
9179
return LDAIMetrics(success=True, usage=get_ai_usage_from_response(response))
9280

9381

82+
# Tool names that require their own API type in the Chat Completions API.
83+
# LD stores all tools as type="function"; these are converted to their correct type.
84+
_NATIVE_API_TOOL_NAMES = frozenset({
85+
'web_search_tool',
86+
'file_search',
87+
'computer_use_preview',
88+
})
89+
90+
91+
def normalize_tool_types(tool_definitions: List[Any]) -> List[Dict[str, Any]]:
92+
"""
93+
Convert LD tool definitions to Chat Completions API format.
94+
95+
LD emits all tools as ``type="function"`` with a flat structure. This helper
96+
wraps regular function tools in the nested ``function`` key the API requires,
97+
and converts known native tool names to their correct API type without a schema.
98+
99+
:param tool_definitions: Tool definitions from the LD AI config
100+
:return: Tool list ready to pass to ``chat.completions.create``
101+
"""
102+
result = []
103+
for td in tool_definitions:
104+
if not isinstance(td, dict):
105+
continue
106+
name = td.get('name', '')
107+
result.append({**td, 'type': name} if name in _NATIVE_API_TOOL_NAMES else td)
108+
return result
109+
110+
94111
# Native tool raw_item type names don't always match the LD config key convention.
95112
_NATIVE_TOOL_TYPE_TO_CONFIG_KEY = {
96113
'web_search': 'web_search_tool',

packages/ai-providers/server-ai-openai/src/ldai_openai/openai_runner_factory.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from ldai.providers import AIProvider, ToolRegistry
66
from openai import AsyncOpenAI
77

8+
from ldai_openai.openai_helper import normalize_tool_types
89
from ldai_openai.openai_model_runner import OpenAIModelRunner
910

1011
if TYPE_CHECKING:
@@ -40,11 +41,17 @@ def create_model(self, config: AIConfigKind) -> OpenAIModelRunner:
4041
Create a configured OpenAIModelRunner for the given AI config.
4142
4243
Reuses the underlying AsyncOpenAI client so connection pooling is preserved.
44+
Tool definitions are converted from LD's flat format to the Chat Completions
45+
API format, with native tools mapped to their correct API type.
4346
4447
:param config: The LaunchDarkly AI configuration
4548
:return: OpenAIModelRunner ready to invoke the model
4649
"""
4750
model_name, parameters = self._extract_model_config(config)
51+
parameters = dict(parameters)
52+
tool_defs = parameters.pop('tools', None) or []
53+
if tool_defs:
54+
parameters['tools'] = normalize_tool_types(tool_defs)
4855
return OpenAIModelRunner(self._client, model_name, parameters)
4956

5057
def create_agent_graph(self, graph_def: Any, tools: ToolRegistry) -> Any:

0 commit comments

Comments
 (0)