diff --git a/python/beeai_framework/memory/utils.py b/python/beeai_framework/memory/utils.py index 0dbe3988b..0c893798d 100644 --- a/python/beeai_framework/memory/utils.py +++ b/python/beeai_framework/memory/utils.py @@ -18,11 +18,12 @@ def extract_last_tool_call_pair(memory: BaseMemory) -> tuple[AssistantMessage, T return None tool_call: AssistantMessage = memory.messages[tool_call_index] # type: ignore + tool_call_id = tool_call.get_tool_calls()[-1].id tool_response_index = find_index( memory.messages, lambda msg: bool( - isinstance(msg, ToolMessage) and msg.get_tool_results()[0].tool_call_id == tool_call.get_tool_calls()[0].id + isinstance(msg, ToolMessage) and any(result.tool_call_id == tool_call_id for result in msg.content) ), reverse_traversal=True, fallback=-1, diff --git a/python/tests/memory/test_utils.py b/python/tests/memory/test_utils.py new file mode 100644 index 000000000..482872808 --- /dev/null +++ b/python/tests/memory/test_utils.py @@ -0,0 +1,58 @@ +# Copyright 2025 © BeeAI a Series of LF Projects, LLC +# SPDX-License-Identifier: Apache-2.0 + +import asyncio + +from beeai_framework.backend.message import ( + AssistantMessage, + MessageToolCallContent, + MessageToolResultContent, + ToolMessage, +) +from beeai_framework.memory.unconstrained_memory import UnconstrainedMemory +from beeai_framework.memory.utils import extract_last_tool_call_pair + + +def test_extract_last_tool_call_pair_returns_matching_tool_response() -> None: + memory = UnconstrainedMemory() + assistant_message = AssistantMessage( + MessageToolCallContent(id="call_1", tool_name="weather", args='{"city":"Paris"}') + ) + tool_message = ToolMessage(MessageToolResultContent(tool_name="weather", tool_call_id="call_1", result="sunny")) + + asyncio.run(memory.add_many([assistant_message, tool_message])) + + assert extract_last_tool_call_pair(memory) == (assistant_message, tool_message) + + +def test_extract_last_tool_call_pair_ignores_empty_tool_messages() -> None: + memory = UnconstrainedMemory() + assistant_message = AssistantMessage( + MessageToolCallContent(id="call_1", tool_name="weather", args='{"city":"Paris"}') + ) + empty_tool_message = ToolMessage([]) + tool_message = ToolMessage(MessageToolResultContent(tool_name="weather", tool_call_id="call_1", result="sunny")) + + asyncio.run(memory.add_many([assistant_message, empty_tool_message, tool_message])) + + assert extract_last_tool_call_pair(memory) == (assistant_message, tool_message) + + +def test_extract_last_tool_call_pair_uses_last_tool_call_from_message() -> None: + memory = UnconstrainedMemory() + assistant_message = AssistantMessage( + [ + MessageToolCallContent(id="call_1", tool_name="weather", args='{"city":"Paris"}'), + MessageToolCallContent(id="call_2", tool_name="weather", args='{"city":"Berlin"}'), + ] + ) + first_tool_message = ToolMessage( + MessageToolResultContent(tool_name="weather", tool_call_id="call_1", result="sunny") + ) + last_tool_message = ToolMessage( + MessageToolResultContent(tool_name="weather", tool_call_id="call_2", result="rainy") + ) + + asyncio.run(memory.add_many([assistant_message, first_tool_message, last_tool_message])) + + assert extract_last_tool_call_pair(memory) == (assistant_message, last_tool_message)