77from collections .abc import Callable
88from datetime import datetime
99from pathlib import Path
10- from typing import TYPE_CHECKING , Any
10+ from typing import Any
1111
1212from networkx import MultiDiGraph
1313
1414from haystack import logging
1515from haystack .core .errors import PipelineInvalidPipelineSnapshotError
1616from haystack .core .pipeline .utils import _deepcopy_with_exceptions
17- from haystack .dataclasses import ChatMessage
18- from haystack .dataclasses .breakpoints import (
19- AgentBreakpoint ,
20- AgentSnapshot ,
21- Breakpoint ,
22- PipelineSnapshot ,
23- PipelineState ,
24- ToolBreakpoint ,
25- )
17+ from haystack .dataclasses .breakpoints import Breakpoint , PipelineSnapshot , PipelineState
2618from haystack .utils .base_serialization import _serialize_value_with_schema
27- from haystack .utils .misc import _get_output_dir
28-
29- if TYPE_CHECKING :
30- from haystack .components .agents .agent import _ExecutionContext
31- from haystack .tools import ToolsType
3219
3320logger = logging .getLogger (__name__ )
3421
@@ -54,34 +41,17 @@ def _is_snapshot_save_enabled() -> bool:
5441 return value in ("true" , "1" )
5542
5643
57- def _validate_break_point_against_pipeline (break_point : Breakpoint | AgentBreakpoint , graph : MultiDiGraph ) -> None :
44+ def _validate_break_point_against_pipeline (break_point : Breakpoint , graph : MultiDiGraph ) -> None :
5845 """
5946 Validates the breakpoints passed to the pipeline.
6047
6148 Makes sure the breakpoint contains a valid components registered in the pipeline.
6249
63- :param break_point: a breakpoint to validate, can be Breakpoint or AgentBreakpoint
50+ :param break_point: a breakpoint to validate
6451 """
65-
66- # all Breakpoints must refer to a valid component in the pipeline
67- if isinstance (break_point , Breakpoint ) and break_point .component_name not in graph .nodes :
52+ if break_point .component_name not in graph .nodes :
6853 raise ValueError (f"break_point { break_point } is not a registered component in the pipeline" )
6954
70- if isinstance (break_point , AgentBreakpoint ):
71- breakpoint_agent_component = graph .nodes .get (break_point .agent_name )
72- if not breakpoint_agent_component :
73- raise ValueError (f"break_point { break_point } is not a registered Agent component in the pipeline" )
74-
75- if isinstance (break_point .break_point , ToolBreakpoint ):
76- instance = breakpoint_agent_component ["instance" ]
77- for tool in instance .tools :
78- if break_point .break_point .tool_name == tool .name :
79- break
80- else :
81- raise ValueError (
82- f"break_point { break_point .break_point } is not a registered tool in the Agent component"
83- )
84-
8555
8656def _validate_pipeline_snapshot_against_pipeline (pipeline_snapshot : PipelineSnapshot , graph : MultiDiGraph ) -> None :
8757 """
@@ -121,11 +91,7 @@ def _validate_pipeline_snapshot_against_pipeline(pipeline_snapshot: PipelineSnap
12191 f"are not part of the current pipeline."
12292 )
12393
124- if isinstance (pipeline_snapshot .break_point , AgentBreakpoint ):
125- component_name = pipeline_snapshot .break_point .agent_name
126- else :
127- component_name = pipeline_snapshot .break_point .component_name
128-
94+ component_name = pipeline_snapshot .break_point .component_name
12995 visit_count = pipeline_snapshot .pipeline_state .component_visits [component_name ]
13096
13197 logger .info (
@@ -216,30 +182,18 @@ def _save_pipeline_snapshot(
216182 return None
217183
218184 break_point = pipeline_snapshot .break_point
219- snapshot_file_path = (
220- break_point .break_point .snapshot_file_path
221- if isinstance (break_point , AgentBreakpoint )
222- else break_point .snapshot_file_path
223- )
185+ snapshot_file_path = break_point .snapshot_file_path
224186
225187 if snapshot_file_path is None :
226188 return None
227189
228190 dt = pipeline_snapshot .timestamp or datetime .now ()
229191 snapshot_dir = Path (snapshot_file_path )
230192
231- # Generate filename
232- # We check if the agent_name is provided to differentiate between agent and non-agent breakpoints
233- if isinstance (break_point , AgentBreakpoint ):
234- agent_name = break_point .agent_name
235- component_name = break_point .break_point .component_name
236- else :
237- component_name = break_point .component_name
238- agent_name = None
239-
193+ component_name = break_point .component_name
240194 visit_nr = pipeline_snapshot .pipeline_state .component_visits .get (component_name , 0 )
241195 timestamp = dt .strftime ("%Y_%m_%d_%H_%M_%S" )
242- file_name = f"{ agent_name + '_' if agent_name else '' } { component_name } _{ visit_nr } _{ timestamp } .json"
196+ file_name = f"{ component_name } _{ visit_nr } _{ timestamp } .json"
243197 full_path = snapshot_dir / file_name
244198
245199 try :
@@ -262,7 +216,7 @@ def _create_pipeline_snapshot(
262216 * ,
263217 inputs : dict [str , Any ],
264218 component_inputs : dict [str , Any ],
265- break_point : AgentBreakpoint | Breakpoint ,
219+ break_point : Breakpoint ,
266220 component_visits : dict [str , int ],
267221 original_input_data : dict [str , Any ],
268222 ordered_component_names : list [str ],
@@ -274,7 +228,7 @@ def _create_pipeline_snapshot(
274228
275229 :param inputs: The current pipeline snapshot inputs.
276230 :param component_inputs: The inputs to the component that triggered the breakpoint.
277- :param break_point: The breakpoint that triggered the snapshot, can be AgentBreakpoint or Breakpoint .
231+ :param break_point: The breakpoint that triggered the snapshot.
278232 :param component_visits: The visit count of the component that triggered the breakpoint.
279233 :param original_input_data: The original input data.
280234 :param ordered_component_names: The ordered component names.
@@ -283,10 +237,7 @@ def _create_pipeline_snapshot(
283237 :returns:
284238 A PipelineSnapshot containing the state of the pipeline at the point of the breakpoint.
285239 """
286- if isinstance (break_point , AgentBreakpoint ):
287- component_name = break_point .agent_name
288- else :
289- component_name = break_point .component_name
240+ component_name = break_point .component_name
290241
291242 transformed_original_input_data = _transform_json_structure (original_input_data )
292243 transformed_inputs = _transform_json_structure ({** inputs , component_name : component_inputs })
@@ -343,31 +294,6 @@ def _transform_json_structure(data: dict[str, Any] | list[Any] | Any) -> Any:
343294 return data
344295
345296
346- def _create_agent_snapshot (
347- * , component_visits : dict [str , int ], agent_breakpoint : AgentBreakpoint , component_inputs : dict [str , Any ]
348- ) -> AgentSnapshot :
349- """
350- Create a snapshot of the agent's state.
351-
352- :param component_visits: The visit counts for the agent's components.
353- :param agent_breakpoint: AgentBreakpoint object containing breakpoints
354- :return: An AgentSnapshot containing the agent's state and component visits.
355- """
356- serialized_chat_generator = _serialize_agent_component_inputs (
357- component_name = "chat_generator" , component_inputs = component_inputs ["chat_generator" ]
358- )
359- serialized_tool_invoker = _serialize_agent_component_inputs (
360- component_name = "tool_invoker" , component_inputs = component_inputs ["tool_invoker" ]
361- )
362-
363- return AgentSnapshot (
364- component_inputs = {"chat_generator" : serialized_chat_generator , "tool_invoker" : serialized_tool_invoker },
365- component_visits = component_visits ,
366- break_point = agent_breakpoint ,
367- timestamp = datetime .now (),
368- )
369-
370-
371297def _serialize_with_field_fallback (payload : Any , * , description : str ) -> dict [str , Any ]:
372298 """
373299 Serialize a payload and, on failure, retry field-by-field to preserve resumable fields.
@@ -417,149 +343,3 @@ def _serialize_with_field_fallback(payload: Any, *, description: str) -> dict[st
417343 "serialization_schema" : {"type" : "object" , "properties" : serialized_properties },
418344 "serialized_data" : serialized_data ,
419345 }
420-
421-
422- def _serialize_agent_component_inputs (component_name : str , component_inputs : dict [str , Any ]) -> dict [str , Any ]:
423- """
424- Serialize agent component inputs while preserving resumable fields whenever possible.
425-
426- Thin wrapper around :func:`_serialize_with_field_fallback` that supplies an agent-specific label
427- for the warning messages.
428-
429- :param component_name: Name of the agent sub-component (e.g. ``chat_generator`` or ``tool_invoker``).
430- :param component_inputs: Runtime inputs for that sub-component.
431- :returns: A serialized payload that is always a structurally valid ``{"serialization_schema",
432- "serialized_data"}`` pair. When every field fails to serialize, an empty-but-valid object
433- payload is returned so that ``_deserialize_value_with_schema`` can still load it (for example
434- when resuming from a ``ToolBreakpoint`` where the sub-component's inputs are not strictly required).
435- """
436- return _serialize_with_field_fallback (component_inputs , description = f"the agent's { component_name } inputs" )
437-
438-
439- def _validate_tool_breakpoint_is_valid (agent_breakpoint : AgentBreakpoint , tools : "ToolsType" ) -> None :
440- """
441- Validates the AgentBreakpoint passed to the agent.
442-
443- Validates that the tool name in ToolBreakpoints correspond to a tool available in the agent.
444-
445- :param agent_breakpoint: AgentBreakpoint object containing breakpoints for the agent components.
446- :param tools: A list of Tool and/or Toolset objects, or a Toolset that the agent can use.
447- :raises ValueError: If any tool name in ToolBreakpoints is not available in the agent's tools.
448- """
449- from haystack .tools .utils import flatten_tools_or_toolsets # avoid circular import
450-
451- available_tool_names = {tool .name for tool in flatten_tools_or_toolsets (tools )}
452- tool_breakpoint = agent_breakpoint .break_point
453- # Assert added for mypy to pass, but this is already checked before this function is called
454- assert isinstance (tool_breakpoint , ToolBreakpoint )
455- if tool_breakpoint .tool_name and tool_breakpoint .tool_name not in available_tool_names :
456- raise ValueError (f"Tool '{ tool_breakpoint .tool_name } ' is not available in the agent's tools" )
457-
458-
459- def _create_pipeline_snapshot_from_chat_generator (
460- * , execution_context : "_ExecutionContext" , agent_name : str | None = None , break_point : AgentBreakpoint | None = None
461- ) -> PipelineSnapshot :
462- """
463- Create a pipeline snapshot when a chat generator breakpoint is raised or an exception during execution occurs.
464-
465- :param execution_context: The current execution context of the agent.
466- :param agent_name: The name of the agent component if present in a pipeline.
467- :param break_point: An optional AgentBreakpoint object. If provided, it will be used instead of creating a new one.
468- A scenario where a new breakpoint is created is when an exception occurs during chat generation and we want to
469- capture the state at that point.
470- :returns:
471- A PipelineSnapshot containing the state of the pipeline and agent at the point of the breakpoint or exception.
472- """
473- if break_point is None :
474- agent_breakpoint = AgentBreakpoint (
475- agent_name = agent_name or "agent" ,
476- break_point = Breakpoint (
477- component_name = "chat_generator" ,
478- visit_count = execution_context .component_visits ["chat_generator" ],
479- snapshot_file_path = _get_output_dir ("pipeline_snapshot" ),
480- ),
481- )
482- else :
483- agent_breakpoint = break_point
484-
485- agent_snapshot = _create_agent_snapshot (
486- component_visits = execution_context .component_visits ,
487- agent_breakpoint = agent_breakpoint ,
488- component_inputs = {
489- "chat_generator" : {
490- "messages" : execution_context .state .data ["messages" ],
491- ** execution_context .chat_generator_inputs ,
492- },
493- "tool_invoker" : {"messages" : [], "state" : execution_context .state , ** execution_context .tool_invoker_inputs },
494- },
495- )
496-
497- return PipelineSnapshot ._from_agent_snapshot (agent_snapshot = agent_snapshot )
498-
499-
500- def _create_pipeline_snapshot_from_tool_invoker (
501- * ,
502- execution_context : "_ExecutionContext" ,
503- tool_name : str | None = None ,
504- agent_name : str | None = None ,
505- break_point : AgentBreakpoint | None = None ,
506- ) -> PipelineSnapshot :
507- """
508- Create a pipeline snapshot when a tool invoker breakpoint is raised or an exception during execution occurs.
509-
510- :param execution_context: The current execution context of the agent.
511- :param tool_name: The name of the tool that triggered the breakpoint, if available.
512- :param agent_name: The name of the agent component if present in a pipeline.
513- :param break_point: An optional AgentBreakpoint object. If provided, it will be used instead of creating a new one.
514- A scenario where a new breakpoint is created is when an exception occurs during tool execution and we want to
515- capture the state at that point.
516- :returns:
517- A PipelineSnapshot containing the state of the pipeline and agent at the point of the breakpoint or exception.
518- """
519- if break_point is None :
520- agent_breakpoint = AgentBreakpoint (
521- agent_name = agent_name or "agent" ,
522- break_point = ToolBreakpoint (
523- component_name = "tool_invoker" ,
524- visit_count = execution_context .component_visits ["tool_invoker" ],
525- tool_name = tool_name ,
526- snapshot_file_path = _get_output_dir ("pipeline_snapshot" ),
527- ),
528- )
529- else :
530- agent_breakpoint = break_point
531-
532- messages = execution_context .state .data ["messages" ]
533- agent_snapshot = _create_agent_snapshot (
534- component_visits = execution_context .component_visits ,
535- agent_breakpoint = agent_breakpoint ,
536- component_inputs = {
537- "chat_generator" : {"messages" : messages [:- 1 ], ** execution_context .chat_generator_inputs },
538- "tool_invoker" : {
539- "messages" : messages [- 1 :], # tool invoker consumes last msg from the chat_generator, contains tool call
540- "state" : execution_context .state ,
541- ** execution_context .tool_invoker_inputs ,
542- },
543- },
544- )
545-
546- # Create an empty pipeline snapshot
547- return PipelineSnapshot ._from_agent_snapshot (agent_snapshot = agent_snapshot )
548-
549-
550- def _should_trigger_tool_invoker_breakpoint (break_point : ToolBreakpoint , llm_messages : list [ChatMessage ]) -> bool :
551- """
552- Determine if a tool invoker breakpoint should be triggered based on the provided ToolBreakpoint and LLM messages.
553-
554- :param break_point: The ToolBreakpoint to check against.
555- :param llm_messages: A list of ChatMessage objects representing the LLM messages.
556- :returns:
557- True if the breakpoint should be triggered, False otherwise.
558- """
559- # Check if we should break for this specific tool or all tools
560- if break_point .tool_name is None :
561- # Break for any tool call
562- return any (msg .tool_call for msg in llm_messages )
563-
564- # Break only for the specific tool
565- return any (tc .tool_name == break_point .tool_name for msg in llm_messages for tc in msg .tool_calls or [])
0 commit comments