Skip to content

Commit d331775

Browse files
jsonbaileyclaude
andcommitted
fix: export get_tool_calls_from_response and sum_token_usage_from_messages; add tests
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 91cd300 commit d331775

2 files changed

Lines changed: 89 additions & 1 deletion

File tree

packages/ai-providers/server-ai-langchain/src/ldai_langchain/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
create_langchain_model,
44
get_ai_metrics_from_response,
55
get_ai_usage_from_response,
6+
get_tool_calls_from_response,
67
map_provider,
8+
sum_token_usage_from_messages,
79
)
810
from ldai_langchain.langchain_model_runner import LangChainModelRunner
911
from ldai_langchain.langchain_runner_factory import LangChainRunnerFactory
@@ -18,5 +20,7 @@
1820
'create_langchain_model',
1921
'get_ai_metrics_from_response',
2022
'get_ai_usage_from_response',
23+
'get_tool_calls_from_response',
2124
'map_provider',
25+
'sum_token_usage_from_messages',
2226
]

packages/ai-providers/server-ai-langchain/tests/test_langchain_provider.py

Lines changed: 85 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,15 @@
77

88
from ldai import LDMessage
99

10-
from ldai_langchain import LangChainModelRunner, LangChainRunnerFactory, convert_messages_to_langchain, get_ai_metrics_from_response, map_provider
10+
from ldai_langchain import (
11+
LangChainModelRunner,
12+
LangChainRunnerFactory,
13+
convert_messages_to_langchain,
14+
get_ai_metrics_from_response,
15+
get_tool_calls_from_response,
16+
map_provider,
17+
sum_token_usage_from_messages,
18+
)
1119

1220

1321
class TestConvertMessages:
@@ -237,6 +245,82 @@ async def test_returns_success_false_when_structured_model_invocation_throws_err
237245
assert result.metrics.usage is None
238246

239247

248+
class TestGetToolCallsFromResponse:
249+
"""Tests for get_tool_calls_from_response."""
250+
251+
def test_returns_tool_call_names_in_order(self):
252+
"""Should return tool call names from response.tool_calls."""
253+
mock_response = MagicMock()
254+
mock_response.tool_calls = [
255+
{'name': 'search', 'args': {}},
256+
{'name': 'calculator', 'args': {}},
257+
]
258+
assert get_tool_calls_from_response(mock_response) == ['search', 'calculator']
259+
260+
def test_returns_empty_list_when_tool_calls_is_empty(self):
261+
"""Should return empty list when tool_calls is an empty list."""
262+
mock_response = MagicMock()
263+
mock_response.tool_calls = []
264+
assert get_tool_calls_from_response(mock_response) == []
265+
266+
def test_returns_empty_list_when_no_tool_calls_attribute(self):
267+
"""Should return empty list when response has no tool_calls attribute."""
268+
mock_response = MagicMock(spec=[])
269+
assert get_tool_calls_from_response(mock_response) == []
270+
271+
def test_returns_empty_list_when_tool_calls_is_not_a_list(self):
272+
"""Should return empty list when tool_calls is not a list."""
273+
mock_response = MagicMock()
274+
mock_response.tool_calls = 'not-a-list'
275+
assert get_tool_calls_from_response(mock_response) == []
276+
277+
def test_skips_tool_calls_without_name(self):
278+
"""Should skip tool calls that have no name."""
279+
mock_response = MagicMock()
280+
mock_response.tool_calls = [{'args': {}}, {'name': 'search', 'args': {}}]
281+
assert get_tool_calls_from_response(mock_response) == ['search']
282+
283+
284+
class TestSumTokenUsageFromMessages:
285+
"""Tests for sum_token_usage_from_messages."""
286+
287+
def test_sums_usage_across_messages(self):
288+
"""Should sum token usage from all messages."""
289+
msg1 = AIMessage(content='a')
290+
msg1.usage_metadata = {'total_tokens': 10, 'input_tokens': 6, 'output_tokens': 4}
291+
msg2 = AIMessage(content='b')
292+
msg2.usage_metadata = {'total_tokens': 20, 'input_tokens': 12, 'output_tokens': 8}
293+
294+
result = sum_token_usage_from_messages([msg1, msg2])
295+
296+
assert result is not None
297+
assert result.total == 30
298+
assert result.input == 18
299+
assert result.output == 12
300+
301+
def test_returns_none_when_no_usage_on_any_message(self):
302+
"""Should return None when no message has usage metadata."""
303+
msg = AIMessage(content='hello')
304+
assert sum_token_usage_from_messages([msg]) is None
305+
306+
def test_returns_none_for_empty_list(self):
307+
"""Should return None for an empty message list."""
308+
assert sum_token_usage_from_messages([]) is None
309+
310+
def test_skips_messages_without_usage(self):
311+
"""Should skip messages that have no usage and sum the rest."""
312+
msg1 = AIMessage(content='a')
313+
msg2 = AIMessage(content='b')
314+
msg2.usage_metadata = {'total_tokens': 5, 'input_tokens': 3, 'output_tokens': 2}
315+
316+
result = sum_token_usage_from_messages([msg1, msg2])
317+
318+
assert result is not None
319+
assert result.total == 5
320+
assert result.input == 3
321+
assert result.output == 2
322+
323+
240324
class TestGetLlm:
241325
"""Tests for LangChainModelRunner.get_llm."""
242326

0 commit comments

Comments
 (0)