Skip to content

Commit c9cc0a9

Browse files
committed
fix sending proper tool name in openai
1 parent 69ba99f commit c9cc0a9

2 files changed

Lines changed: 39 additions & 7 deletions

File tree

packages/ai-providers/server-ai-langchain/tests/test_tracking_langgraph.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -254,12 +254,23 @@ async def test_tracks_tool_calls():
254254
"""A tool_call event fires for each tool name found in the model response."""
255255
mock_ld_client = MagicMock()
256256
graph = _make_graph(mock_ld_client, tool_names=['get_weather'])
257-
fake_response = _make_fake_response('Calling tool.', tool_call_names=['get_weather'])
258257

259-
tool_registry = {'get_weather': lambda location='NYC': 'sunny'}
258+
# Model returns a tool call on the first invoke, then a final answer.
259+
tool_response = _make_fake_response('Calling tool.', tool_call_names=['get_weather'])
260+
final_response = _make_fake_response('It is sunny in NYC.')
261+
262+
mock_model = MagicMock()
263+
mock_model.invoke.side_effect = [tool_response, final_response]
264+
mock_model.bind_tools.return_value = mock_model
265+
266+
def get_weather(location: str = 'NYC') -> str:
267+
"""Return the current weather for a location."""
268+
return 'sunny'
269+
270+
tool_registry = {'get_weather': get_weather}
260271

261272
with patch('ldai_langchain.langgraph_agent_graph_runner.create_langchain_model',
262-
return_value=_mock_model(fake_response)):
273+
return_value=mock_model):
263274
runner = LangGraphAgentGraphRunner(graph, tool_registry)
264275
await runner.run('What is the weather?')
265276

@@ -274,12 +285,27 @@ async def test_tracks_multiple_tool_calls():
274285
"""One tool_call event fires per tool name in the response."""
275286
mock_ld_client = MagicMock()
276287
graph = _make_graph(mock_ld_client, tool_names=['search', 'summarize'])
277-
fake_response = _make_fake_response('Done.', tool_call_names=['search', 'summarize'])
278288

279-
tool_registry = {'search': lambda q='': q, 'summarize': lambda text='': text}
289+
# Both tools called in one response; second invoke returns a final answer.
290+
tool_response = _make_fake_response('Done.', tool_call_names=['search', 'summarize'])
291+
final_response = _make_fake_response('Here is the summary.')
292+
293+
mock_model = MagicMock()
294+
mock_model.invoke.side_effect = [tool_response, final_response]
295+
mock_model.bind_tools.return_value = mock_model
296+
297+
def search(q: str = '') -> str:
298+
"""Search for information."""
299+
return q
300+
301+
def summarize(text: str = '') -> str:
302+
"""Summarize the given text."""
303+
return text
304+
305+
tool_registry = {'search': search, 'summarize': summarize}
280306

281307
with patch('ldai_langchain.langgraph_agent_graph_runner.create_langchain_model',
282-
return_value=_mock_model(fake_response)):
308+
return_value=mock_model):
283309
runner = LangGraphAgentGraphRunner(graph, tool_registry)
284310
await runner.run('Search and summarize.')
285311

packages/ai-providers/server-ai-openai/src/ldai_openai/openai_agent_graph_runner.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ def __init__(self, graph: AgentGraphDefinition, tools: ToolRegistry):
5151
self._graph = graph
5252
self._tools = tools
5353
self._agent_name_map: Dict[str, str] = {}
54+
self._tool_name_map: Dict[str, str] = {}
5455

5556
async def run(self, input: Any) -> AgentGraphResult:
5657
"""
@@ -140,6 +141,7 @@ def _build_agents(self, path: List[str], state: _RunState) -> Any:
140141

141142
tracker = self._graph.get_tracker()
142143
name_map: Dict[str, str] = {}
144+
tool_name_map: Dict[str, str] = {}
143145

144146
def build_node(node: AgentGraphNode, ctx: dict) -> Any:
145147
node_config = node.get_config()
@@ -180,6 +182,8 @@ def build_node(node: AgentGraphNode, ctx: dict) -> Any:
180182
if not tool_fn:
181183
continue
182184

185+
# Map fn.__name__ → config key so tracked names match the AI config.
186+
tool_name_map[tool_fn.__name__] = tool_name
183187
agent_tools.append(function_tool(tool_fn))
184188

185189
return Agent(
@@ -192,6 +196,7 @@ def build_node(node: AgentGraphNode, ctx: dict) -> Any:
192196

193197
root = self._graph.reverse_traverse(fn=build_node)
194198
self._agent_name_map = name_map
199+
self._tool_name_map = tool_name_map
195200
return root
196201

197202
def _make_on_handoff(
@@ -280,8 +285,9 @@ def _flush_final_segment(
280285
def _track_tool_calls(self, result: Any, tracker: Any) -> None:
281286
"""Track all tool calls from the run result, attributed to the node that called them."""
282287
gk = tracker.graph_key if tracker is not None else None
283-
for agent_name, tool_name in get_tool_calls_from_run_items(result.new_items):
288+
for agent_name, tool_fn_name in get_tool_calls_from_run_items(result.new_items):
284289
original_key = self._agent_name_map.get(agent_name, agent_name)
290+
tool_name = self._tool_name_map.get(tool_fn_name, '')
285291
node = self._graph.get_node(original_key)
286292
if node is None:
287293
continue

0 commit comments

Comments
 (0)