Skip to content

Commit 6f06529

Browse files
authored
fix: Fix issue where utility function get_tool_calls_and_descriptions_from_snapshot modifies agent_snapshot in place (#421)
* Fix issue where utility function modifies agent snapshot in place * Pin lazy imports to be compatible with python 3.9
1 parent fe78d16 commit 6f06529

File tree

3 files changed

+50
-4
lines changed

3 files changed

+50
-4
lines changed

haystack_experimental/components/agents/human_in_the_loop/breakpoint.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5+
from copy import deepcopy
6+
57
from haystack.dataclasses.breakpoints import AgentSnapshot, ToolBreakpoint
68
from haystack.utils import _deserialize_value_with_schema
79

@@ -31,7 +33,7 @@ def get_tool_calls_and_descriptions_from_snapshot(
3133
tool_caused_break_point = break_point.tool_name
3234

3335
# Deserialize the tool invoker inputs from the snapshot
34-
tool_invoker_inputs = _deserialize_value_with_schema(agent_snapshot.component_inputs["tool_invoker"])
36+
tool_invoker_inputs = _deserialize_value_with_schema(deepcopy(agent_snapshot.component_inputs["tool_invoker"]))
3537
tool_call_messages = tool_invoker_inputs["messages"]
3638
state = tool_invoker_inputs["state"]
3739
tool_name_to_tool = {t.name: t for t in tool_invoker_inputs["tools"]}

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ classifiers = [
2929
dependencies = [
3030
"haystack-ai",
3131
"rich", # For pretty printing in the console used by human-in-the-loop utilities
32+
"lazy-imports<1.2.0" # 1.2.0 requires Python 3.10+, see https://github.com/bachorp/lazy-imports/releases/tag/1.2.0
3233
]
3334

3435
[project.urls]

test/components/agents/test_agent.py

Lines changed: 46 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5+
import copy
56
import os
67
from pathlib import Path
78
from typing import Any, Optional
@@ -12,13 +13,11 @@
1213
from haystack.components.generators.chat import OpenAIChatGenerator
1314
from haystack.core.errors import BreakpointException
1415
from haystack.core.pipeline.breakpoint import load_pipeline_snapshot
15-
from haystack.dataclasses import ChatMessage
16+
from haystack.dataclasses import ChatMessage, ToolCall
1617
from haystack.dataclasses.breakpoints import PipelineSnapshot
1718
from haystack.tools import Tool, create_tool_from_function
1819

1920
from haystack_experimental.chat_message_stores.in_memory import InMemoryChatMessageStore
20-
from haystack_experimental.components.retrievers import ChatMessageRetriever
21-
from haystack_experimental.components.writers import ChatMessageWriter
2221
from haystack_experimental.components.agents.agent import Agent
2322
from haystack_experimental.components.agents.human_in_the_loop import (
2423
AlwaysAskPolicy,
@@ -34,6 +33,8 @@
3433
from haystack_experimental.components.agents.human_in_the_loop.breakpoint import (
3534
get_tool_calls_and_descriptions_from_snapshot,
3635
)
36+
from haystack_experimental.components.retrievers import ChatMessageRetriever
37+
from haystack_experimental.components.writers import ChatMessageWriter
3738

3839

3940
@pytest.fixture
@@ -50,6 +51,19 @@ def run(self, messages: list[ChatMessage], tools: Any) -> dict[str, list[ChatMes
5051
return {"replies": [ChatMessage.from_assistant("This is a mock response.")]}
5152

5253

54+
@component
55+
class MockChatGeneratorToolsResponse:
56+
@component.output_types(replies=list[ChatMessage])
57+
def run(self, messages: list[ChatMessage], tools: Any) -> dict[str, list[ChatMessage]]:
58+
return {
59+
"replies": [
60+
ChatMessage.from_assistant(
61+
tool_calls=[ToolCall(tool_name="addition_tool", arguments={"a": 2, "b": 3})]
62+
)
63+
]
64+
}
65+
66+
5367
@component
5468
class MockAgent:
5569
def __init__(self, system_prompt: Optional[str] = None):
@@ -257,6 +271,35 @@ def test_from_dict(self, tools, confirmation_strategies, monkeypatch):
257271

258272

259273
class TestAgentConfirmationStrategy:
274+
def test_get_tool_calls_and_descriptions_from_snapshot_no_mutation_of_snapshot(self, tools, tmp_path):
275+
agent = Agent(
276+
chat_generator=MockChatGeneratorToolsResponse(),
277+
tools=tools,
278+
confirmation_strategies={
279+
"addition_tool": BreakpointConfirmationStrategy(snapshot_file_path=str(tmp_path)),
280+
},
281+
)
282+
agent.warm_up()
283+
284+
# Run the agent to create a snapshot with a breakpoint
285+
try:
286+
agent.run([ChatMessage.from_user("What is 2+2?")])
287+
except BreakpointException:
288+
pass
289+
290+
# Load the latest snapshot from disk
291+
loaded_snapshot = get_latest_snapshot(snapshot_file_path=str(tmp_path))
292+
293+
original_snapshot = copy.deepcopy(loaded_snapshot)
294+
295+
# Extract tool calls and descriptions
296+
_ = get_tool_calls_and_descriptions_from_snapshot(
297+
agent_snapshot=loaded_snapshot.agent_snapshot, breakpoint_tool_only=True
298+
)
299+
300+
# Verify that the original snapshot has not been mutated
301+
assert loaded_snapshot == original_snapshot
302+
260303
@pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set")
261304
@pytest.mark.integration
262305
def test_run_blocking_confirmation_strategy_modify(self, tools):

0 commit comments

Comments
 (0)