22#
33# SPDX-License-Identifier: Apache-2.0
44
5+ import copy
56import os
67from pathlib import Path
78from typing import Any , Optional
1213from haystack .components .generators .chat import OpenAIChatGenerator
1314from haystack .core .errors import BreakpointException
1415from haystack .core .pipeline .breakpoint import load_pipeline_snapshot
15- from haystack .dataclasses import ChatMessage
16+ from haystack .dataclasses import ChatMessage , ToolCall
1617from haystack .dataclasses .breakpoints import PipelineSnapshot
1718from haystack .tools import Tool , create_tool_from_function
1819
1920from haystack_experimental .chat_message_stores .in_memory import InMemoryChatMessageStore
20- from haystack_experimental .components .retrievers import ChatMessageRetriever
21- from haystack_experimental .components .writers import ChatMessageWriter
2221from haystack_experimental .components .agents .agent import Agent
2322from haystack_experimental .components .agents .human_in_the_loop import (
2423 AlwaysAskPolicy ,
3433from haystack_experimental .components .agents .human_in_the_loop .breakpoint import (
3534 get_tool_calls_and_descriptions_from_snapshot ,
3635)
36+ from haystack_experimental .components .retrievers import ChatMessageRetriever
37+ from haystack_experimental .components .writers import ChatMessageWriter
3738
3839
3940@pytest .fixture
@@ -50,6 +51,19 @@ def run(self, messages: list[ChatMessage], tools: Any) -> dict[str, list[ChatMes
5051 return {"replies" : [ChatMessage .from_assistant ("This is a mock response." )]}
5152
5253
54+ @component
55+ class MockChatGeneratorToolsResponse :
56+ @component .output_types (replies = list [ChatMessage ])
57+ def run (self , messages : list [ChatMessage ], tools : Any ) -> dict [str , list [ChatMessage ]]:
58+ return {
59+ "replies" : [
60+ ChatMessage .from_assistant (
61+ tool_calls = [ToolCall (tool_name = "addition_tool" , arguments = {"a" : 2 , "b" : 3 })]
62+ )
63+ ]
64+ }
65+
66+
5367@component
5468class MockAgent :
5569 def __init__ (self , system_prompt : Optional [str ] = None ):
@@ -257,6 +271,35 @@ def test_from_dict(self, tools, confirmation_strategies, monkeypatch):
257271
258272
259273class TestAgentConfirmationStrategy :
274+ def test_get_tool_calls_and_descriptions_from_snapshot_no_mutation_of_snapshot (self , tools , tmp_path ):
275+ agent = Agent (
276+ chat_generator = MockChatGeneratorToolsResponse (),
277+ tools = tools ,
278+ confirmation_strategies = {
279+ "addition_tool" : BreakpointConfirmationStrategy (snapshot_file_path = str (tmp_path )),
280+ },
281+ )
282+ agent .warm_up ()
283+
284+ # Run the agent to create a snapshot with a breakpoint
285+ try :
286+ agent .run ([ChatMessage .from_user ("What is 2+2?" )])
287+ except BreakpointException :
288+ pass
289+
290+ # Load the latest snapshot from disk
291+ loaded_snapshot = get_latest_snapshot (snapshot_file_path = str (tmp_path ))
292+
293+ original_snapshot = copy .deepcopy (loaded_snapshot )
294+
295+ # Extract tool calls and descriptions
296+ _ = get_tool_calls_and_descriptions_from_snapshot (
297+ agent_snapshot = loaded_snapshot .agent_snapshot , breakpoint_tool_only = True
298+ )
299+
300+ # Verify that the original snapshot has not been mutated
301+ assert loaded_snapshot == original_snapshot
302+
260303 @pytest .mark .skipif (not os .environ .get ("OPENAI_API_KEY" ), reason = "OPENAI_API_KEY not set" )
261304 @pytest .mark .integration
262305 def test_run_blocking_confirmation_strategy_modify (self , tools ):
0 commit comments