Skip to content

Commit 51f3ba6

Browse files
committed
fix openai tool call tracking and lint issues
1 parent 45ac282 commit 51f3ba6

7 files changed

Lines changed: 68 additions & 57 deletions

File tree

packages/ai-providers/server-ai-langchain/src/ldai_langchain/langchain_agent_runner.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,10 @@
66
from ldai.providers import AgentResult, AgentRunner
77
from ldai.providers.types import LDAIMetrics
88

9-
from ldai_langchain.langchain_helper import extract_last_message_content, sum_token_usage_from_messages
9+
from ldai_langchain.langchain_helper import (
10+
extract_last_message_content,
11+
sum_token_usage_from_messages,
12+
)
1013

1114

1215
class LangChainAgentRunner(AgentRunner):

packages/ai-providers/server-ai-langchain/src/ldai_langchain/langchain_helper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, List, Optional
1+
from typing import Any, Dict, List, Optional, Union
22

33
from langchain_core.language_models.chat_models import BaseChatModel
44
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage

packages/ai-providers/server-ai-openai/src/ldai_openai/openai_agent_graph_runner.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ def build_node(node: AgentGraphNode, ctx: dict) -> Any:
185185
# Map runtime tool name → LD config key for metrics (function __name__
186186
# for callables; identity for native tool instances — see get_tool_calls_from_run_items).
187187
if is_agent_tool_instance(tool_fn):
188-
tool_name_map[f'{tool_fn.name}_call'] = tool_name
188+
tool_name_map[tool_fn.name] = tool_name
189189
else:
190190
tool_name_map[tool_fn.__name__] = tool_name
191191
agent_tools.append(registry_value_to_agent_tool(tool_fn))
@@ -290,10 +290,11 @@ def _track_tool_calls(self, result: Any, tracker: Any) -> None:
290290
"""Track all tool calls from the run result, attributed to the node that called them."""
291291
gk = tracker.graph_key if tracker is not None else None
292292
for agent_name, tool_fn_name in get_tool_calls_from_run_items(result.new_items):
293-
log.info(f"Tracking tool call: agent_name={agent_name}, tool_fn_name={tool_fn_name}")
294-
original_key = self._agent_name_map.get(agent_name, agent_name)
295-
tool_name = self._tool_name_map.get(tool_fn_name, '')
296-
node = self._graph.get_node(original_key)
293+
agent_key = self._agent_name_map.get(agent_name, agent_name)
294+
tool_name = self._tool_name_map.get(tool_fn_name)
295+
if tool_name is None:
296+
continue
297+
node = self._graph.get_node(agent_key)
297298
if node is None:
298299
continue
299300
config_tracker = node.get_config().tracker

packages/ai-providers/server-ai-openai/src/ldai_openai/openai_agent_runner.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,10 @@
66
from ldai.providers import AgentResult, AgentRunner, ToolRegistry
77
from ldai.providers.types import LDAIMetrics
88

9-
from ldai_openai.openai_helper import get_ai_usage_from_response, registry_value_to_agent_tool
9+
from ldai_openai.openai_helper import (
10+
get_ai_usage_from_response,
11+
registry_value_to_agent_tool,
12+
)
1013

1114

1215
class OpenAIAgentRunner(AgentRunner):

packages/ai-providers/server-ai-openai/src/ldai_openai/openai_helper.py

Lines changed: 14 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import typing
21
from typing import Any, Dict, Iterable, List, Optional, Tuple, cast
32

43
from ldai import LDMessage
@@ -7,7 +6,6 @@
76
from openai.types.chat import ChatCompletionMessageParam
87

98

10-
119
def convert_messages_to_openai(messages: List[LDMessage]) -> Iterable[ChatCompletionMessageParam]:
1210
"""
1311
Convert LaunchDarkly messages to OpenAI chat completion message format.
@@ -109,43 +107,9 @@ def normalize_tool_types(tool_definitions: List[Any]) -> List[Dict[str, Any]]:
109107
return result
110108

111109

112-
# Native tool raw_item type names don't always match the LD config key convention.
113-
_NATIVE_TOOL_TYPE_TO_CONFIG_KEY = {
114-
'web_search': 'web_search_tool',
115-
}
116-
117-
# ``agents.Tool`` is a typing.Union of concrete tool classes, not a runtime class.
118-
# Using ``isinstance(x, Tool)`` raises TypeError (subscripted generics / union checks).
119-
_AGENT_TOOL_TYPES: Optional[Tuple[type, ...]] = None
120-
121-
122-
def _concrete_agent_tool_types() -> Tuple[type, ...]:
123-
"""Resolve concrete classes behind ``agents.Tool`` (a Union alias)."""
124-
try:
125-
from agents import Tool as ToolUnion
126-
except ImportError:
127-
return ()
128-
args = typing.get_args(ToolUnion)
129-
if not args:
130-
return ()
131-
out: List[type] = []
132-
for a in args:
133-
origin = getattr(a, '__origin__', None)
134-
if origin is not None and isinstance(origin, type):
135-
out.append(origin)
136-
elif isinstance(a, type):
137-
out.append(a)
138-
return tuple(out)
139-
140-
141110
def is_agent_tool_instance(value: Any) -> bool:
142111
"""True if ``value`` is already an openai-agents tool object (not a plain callable)."""
143-
global _AGENT_TOOL_TYPES
144-
if _AGENT_TOOL_TYPES is None:
145-
_AGENT_TOOL_TYPES = _concrete_agent_tool_types()
146-
if not _AGENT_TOOL_TYPES:
147-
return False
148-
return isinstance(value, _AGENT_TOOL_TYPES)
112+
return not callable(value)
149113

150114

151115
def registry_value_to_agent_tool(value: Any) -> Any:
@@ -156,19 +120,27 @@ def registry_value_to_agent_tool(value: Any) -> Any:
156120
tool instances (e.g. ``WebSearchTool()``, ``FileSearchTool(...)``) are
157121
returned unchanged so they are not double-wrapped.
158122
"""
123+
if is_agent_tool_instance(value):
124+
return value
159125
try:
160126
from agents import function_tool
161127
except ImportError as exc:
162128
raise ImportError(
163129
"openai-agents is required for agent tools. "
164130
"Install it with: pip install openai-agents"
165131
) from exc
166-
167-
if is_agent_tool_instance(value):
168-
return value
169132
return function_tool(value)
170133

171134

135+
# Native tool response types do not match the SDK or LD tool name; this map aligns them.
136+
# Function tools are omitted—they already arrive as ``ResponseFunctionToolCall.name``.
137+
_RESPONSE_TYPE_TO_TOOL_NAME: Dict[str, str] = {
138+
'web_search_call': 'web_search',
139+
'file_search_call': 'file_search',
140+
'code_interpreter_call': 'code_interpreter',
141+
}
142+
143+
172144
def get_tool_calls_from_run_items(new_items: List[Any]) -> List[Tuple[str, str]]:
173145
"""
174146
Extract (agent_name, tool_name) pairs from RunResult.new_items.
@@ -197,9 +169,9 @@ def get_tool_calls_from_run_items(new_items: List[Any]) -> List[Tuple[str, str]]
197169
tool_name = raw.name
198170
else:
199171
raw_type = getattr(raw, 'type', None) or (raw.get('type') if isinstance(raw, dict) else None)
200-
if not raw_type:
172+
if not isinstance(raw_type, str):
201173
continue
202-
tool_name = _NATIVE_TOOL_TYPE_TO_CONFIG_KEY.get(raw_type, raw_type)
174+
tool_name = _RESPONSE_TYPE_TO_TOOL_NAME.get(raw_type, raw_type)
203175
if tool_name:
204176
result.append((agent_name, tool_name))
205177
return result

packages/ai-providers/server-ai-openai/tests/test_tracking_openai_agents.py

Lines changed: 38 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,19 @@ def _make_run_result(
109109
return result
110110

111111

112+
def _tool_registry(*config_names: str) -> dict:
113+
"""Registry entries whose callable __name__ matches runtime tool names from the SDK."""
114+
115+
def _stub(name: str):
116+
def fn():
117+
pass
118+
119+
fn.__name__ = name
120+
return fn
121+
122+
return {n: _stub(n) for n in config_names}
123+
124+
112125
def _make_tool_call_item(agent_name: str, tool_name: str) -> MagicMock:
113126
"""
114127
Create a mock ToolCallItem with a ResponseFunctionToolCall raw item so that
@@ -300,15 +313,15 @@ async def test_tracks_graph_key_on_node_events():
300313

301314
@pytest.mark.asyncio
302315
async def test_tracks_tool_calls_from_run_items():
303-
"""A tool_call event fires for each tool found in RunResult.new_items."""
316+
"""A tool_call event fires for tools registered on the graph and in the tool registry."""
304317
mock_ld_client = MagicMock()
305-
graph = _make_graph(mock_ld_client, node_key='root-agent')
318+
graph = _make_graph(mock_ld_client, node_key='root-agent', tool_names=['get_weather'])
306319

307320
tool_item = _make_tool_call_item('root-agent', 'get_weather')
308321
run_result = _make_run_result(output='done', tool_call_items=[tool_item])
309322

310323
with patch.dict('sys.modules', _make_agents_modules(run_result)):
311-
runner = OpenAIAgentGraphRunner(graph, {})
324+
runner = OpenAIAgentGraphRunner(graph, _tool_registry('get_weather'))
312325
await runner.run('What is the weather?')
313326

314327
ev = _events(mock_ld_client)
@@ -319,9 +332,11 @@ async def test_tracks_tool_calls_from_run_items():
319332

320333
@pytest.mark.asyncio
321334
async def test_tracks_multiple_tool_calls():
322-
"""One tool_call event fires per tool in RunResult.new_items."""
335+
"""One tool_call event fires per registered tool in RunResult.new_items."""
323336
mock_ld_client = MagicMock()
324-
graph = _make_graph(mock_ld_client, node_key='root-agent')
337+
graph = _make_graph(
338+
mock_ld_client, node_key='root-agent', tool_names=['search', 'summarize']
339+
)
325340

326341
items = [
327342
_make_tool_call_item('root-agent', 'search'),
@@ -330,14 +345,31 @@ async def test_tracks_multiple_tool_calls():
330345
run_result = _make_run_result(output='done', tool_call_items=items)
331346

332347
with patch.dict('sys.modules', _make_agents_modules(run_result)):
333-
runner = OpenAIAgentGraphRunner(graph, {})
348+
runner = OpenAIAgentGraphRunner(graph, _tool_registry('search', 'summarize'))
334349
await runner.run('Search and summarize.')
335350

336351
ev = _events(mock_ld_client)
337352
tool_keys = [data['toolKey'] for data, _ in ev.get('$ld:ai:tool_call', [])]
338353
assert sorted(tool_keys) == ['search', 'summarize']
339354

340355

356+
@pytest.mark.asyncio
357+
async def test_does_not_track_tool_calls_without_graph_and_registry_config():
358+
"""RunResult tool items that are not backed by graph + registry tools are ignored."""
359+
mock_ld_client = MagicMock()
360+
graph = _make_graph(mock_ld_client, node_key='root-agent')
361+
362+
tool_item = _make_tool_call_item('root-agent', 'orphan_tool')
363+
run_result = _make_run_result(output='done', tool_call_items=[tool_item])
364+
365+
with patch.dict('sys.modules', _make_agents_modules(run_result)):
366+
runner = OpenAIAgentGraphRunner(graph, {})
367+
await runner.run('prompt')
368+
369+
ev = _events(mock_ld_client)
370+
assert ev.get('$ld:ai:tool_call', []) == []
371+
372+
341373
@pytest.mark.asyncio
342374
async def test_tracks_failure_and_latency_on_runner_error():
343375
"""When Runner.run raises, failure and latency events fire; success does not."""

packages/sdk/server-ai/src/ldai/providers/runner_factory.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33

44
from ldai import log
55
from ldai.models import AIConfigKind
6-
from ldai.providers.ai_provider import AIProvider
76
from ldai.providers.agent_graph_runner import AgentGraphRunner
87
from ldai.providers.agent_runner import AgentRunner
8+
from ldai.providers.ai_provider import AIProvider
99
from ldai.providers.model_runner import ModelRunner
1010

1111
T = TypeVar('T')

0 commit comments

Comments
 (0)